|
|
""" |
|
|
モデルレジストリ |
|
|
|
|
|
開放閉鎖原則(OCP)に準拠し、新モデル追加時に |
|
|
既存コードの変更を不要にする |
|
|
""" |
|
|
from typing import Dict, List, Optional, Type |
|
|
|
|
|
from .base import BaseLanguageModel, ModelConfig |
|
|
from .gpt2 import GPT2Model, GPT2_SMALL_CONFIG, GPT2_MEDIUM_CONFIG |
|
|
from .gpt_neo import GPTNeoModel, GPT_NEO_125M_CONFIG |
|
|
from .opt import OPTModel, OPT_125M_CONFIG |
|
|
|
|
|
|
|
|
from .gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG |
|
|
from .pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG |
|
|
from .olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG |
|
|
from .bloom import BLOOMModel, BLOOM_560M_CONFIG |
|
|
|
|
|
|
|
|
from .llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG |
|
|
from .qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG |
|
|
from .mistral import MistralModel, MISTRAL_7B_CONFIG |
|
|
|
|
|
|
|
|
class ModelRegistry: |
|
|
""" |
|
|
モデルレジストリ |
|
|
|
|
|
利用可能なモデルを管理し、キーに基づいて |
|
|
適切なモデルインスタンスを提供する |
|
|
""" |
|
|
|
|
|
_registry: Dict[str, tuple[Type[BaseLanguageModel], ModelConfig]] = {} |
|
|
|
|
|
@classmethod |
|
|
def register( |
|
|
cls, |
|
|
key: str, |
|
|
model_class: Type[BaseLanguageModel], |
|
|
config: ModelConfig, |
|
|
) -> None: |
|
|
""" |
|
|
新しいモデルをレジストリに登録 |
|
|
|
|
|
Args: |
|
|
key: モデルを識別するキー |
|
|
model_class: モデルクラス |
|
|
config: モデル設定 |
|
|
""" |
|
|
cls._registry[key] = (model_class, config) |
|
|
|
|
|
@classmethod |
|
|
def get(cls, key: str) -> BaseLanguageModel: |
|
|
""" |
|
|
キーに対応するモデルインスタンスを取得 |
|
|
|
|
|
Args: |
|
|
key: モデルを識別するキー |
|
|
|
|
|
Returns: |
|
|
モデルインスタンス |
|
|
|
|
|
Raises: |
|
|
KeyError: 指定されたキーが存在しない場合 |
|
|
""" |
|
|
if key not in cls._registry: |
|
|
available = ", ".join(cls._registry.keys()) |
|
|
raise KeyError(f"Model '{key}' not found. Available: {available}") |
|
|
|
|
|
model_class, config = cls._registry[key] |
|
|
return model_class(config) |
|
|
|
|
|
@classmethod |
|
|
def list_models(cls) -> List[str]: |
|
|
"""登録済みモデルのキー一覧を取得""" |
|
|
return list(cls._registry.keys()) |
|
|
|
|
|
@classmethod |
|
|
def get_config(cls, key: str) -> Optional[ModelConfig]: |
|
|
"""指定キーのモデル設定を取得""" |
|
|
if key not in cls._registry: |
|
|
return None |
|
|
return cls._registry[key][1] |
|
|
|
|
|
@classmethod |
|
|
def get_all_configs(cls) -> Dict[str, ModelConfig]: |
|
|
"""すべてのモデル設定を取得""" |
|
|
return {key: config for key, (_, config) in cls._registry.items()} |
|
|
|
|
|
|
|
|
|
|
|
ModelRegistry.register("gpt2", GPT2Model, GPT2_SMALL_CONFIG) |
|
|
ModelRegistry.register("gpt2-medium", GPT2Model, GPT2_MEDIUM_CONFIG) |
|
|
ModelRegistry.register("gpt-neo-125m", GPTNeoModel, GPT_NEO_125M_CONFIG) |
|
|
ModelRegistry.register("opt-125m", OPTModel, OPT_125M_CONFIG) |
|
|
|
|
|
|
|
|
ModelRegistry.register("gpt-oss-20b", GPTOSSModel, GPT_OSS_20B_CONFIG) |
|
|
ModelRegistry.register("pythia-410m", PythiaModel, PYTHIA_410M_CONFIG) |
|
|
ModelRegistry.register("pythia-1b", PythiaModel, PYTHIA_1B_CONFIG) |
|
|
ModelRegistry.register("olmo-1b", OLMoModel, OLMO_1B_CONFIG) |
|
|
ModelRegistry.register("olmo-7b", OLMoModel, OLMO_7B_CONFIG) |
|
|
ModelRegistry.register("bloom-560m", BLOOMModel, BLOOM_560M_CONFIG) |
|
|
|
|
|
|
|
|
ModelRegistry.register("llama-3.2-1b", LlamaModel, LLAMA_3_2_1B_CONFIG) |
|
|
ModelRegistry.register("llama-3.2-3b", LlamaModel, LLAMA_3_2_3B_CONFIG) |
|
|
ModelRegistry.register("qwen2.5-0.5b", QwenModel, QWEN_2_5_0_5B_CONFIG) |
|
|
ModelRegistry.register("qwen2.5-1.5b", QwenModel, QWEN_2_5_1_5B_CONFIG) |
|
|
ModelRegistry.register("mistral-7b", MistralModel, MISTRAL_7B_CONFIG) |
|
|
|
|
|
|
|
|
DEFAULT_MODEL_KEY = "gpt2" |
|
|
|