""" モデルレジストリ 開放閉鎖原則(OCP)に準拠し、新モデル追加時に 既存コードの変更を不要にする """ from typing import Dict, List, Optional, Type from .base import BaseLanguageModel, ModelConfig from .gpt2 import GPT2Model, GPT2_SMALL_CONFIG, GPT2_MEDIUM_CONFIG from .gpt_neo import GPTNeoModel, GPT_NEO_125M_CONFIG from .opt import OPTModel, OPT_125M_CONFIG # Phase 1: GPT-OSS and Fully Open Source Models from .gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG from .pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG from .olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG from .bloom import BLOOMModel, BLOOM_560M_CONFIG # Phase 2: Latest Architecture Models from .llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG from .qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG from .mistral import MistralModel, MISTRAL_7B_CONFIG class ModelRegistry: """ モデルレジストリ 利用可能なモデルを管理し、キーに基づいて 適切なモデルインスタンスを提供する """ _registry: Dict[str, tuple[Type[BaseLanguageModel], ModelConfig]] = {} @classmethod def register( cls, key: str, model_class: Type[BaseLanguageModel], config: ModelConfig, ) -> None: """ 新しいモデルをレジストリに登録 Args: key: モデルを識別するキー model_class: モデルクラス config: モデル設定 """ cls._registry[key] = (model_class, config) @classmethod def get(cls, key: str) -> BaseLanguageModel: """ キーに対応するモデルインスタンスを取得 Args: key: モデルを識別するキー Returns: モデルインスタンス Raises: KeyError: 指定されたキーが存在しない場合 """ if key not in cls._registry: available = ", ".join(cls._registry.keys()) raise KeyError(f"Model '{key}' not found. Available: {available}") model_class, config = cls._registry[key] return model_class(config) @classmethod def list_models(cls) -> List[str]: """登録済みモデルのキー一覧を取得""" return list(cls._registry.keys()) @classmethod def get_config(cls, key: str) -> Optional[ModelConfig]: """指定キーのモデル設定を取得""" if key not in cls._registry: return None return cls._registry[key][1] @classmethod def get_all_configs(cls) -> Dict[str, ModelConfig]: """すべてのモデル設定を取得""" return {key: config for key, (_, config) in cls._registry.items()} # デフォルトモデルの登録 ModelRegistry.register("gpt2", GPT2Model, GPT2_SMALL_CONFIG) ModelRegistry.register("gpt2-medium", GPT2Model, GPT2_MEDIUM_CONFIG) ModelRegistry.register("gpt-neo-125m", GPTNeoModel, GPT_NEO_125M_CONFIG) ModelRegistry.register("opt-125m", OPTModel, OPT_125M_CONFIG) # Phase 1: GPT-OSS and Fully Open Source Models ModelRegistry.register("gpt-oss-20b", GPTOSSModel, GPT_OSS_20B_CONFIG) ModelRegistry.register("pythia-410m", PythiaModel, PYTHIA_410M_CONFIG) ModelRegistry.register("pythia-1b", PythiaModel, PYTHIA_1B_CONFIG) ModelRegistry.register("olmo-1b", OLMoModel, OLMO_1B_CONFIG) ModelRegistry.register("olmo-7b", OLMoModel, OLMO_7B_CONFIG) ModelRegistry.register("bloom-560m", BLOOMModel, BLOOM_560M_CONFIG) # Phase 2: Latest Architecture Models ModelRegistry.register("llama-3.2-1b", LlamaModel, LLAMA_3_2_1B_CONFIG) ModelRegistry.register("llama-3.2-3b", LlamaModel, LLAMA_3_2_3B_CONFIG) ModelRegistry.register("qwen2.5-0.5b", QwenModel, QWEN_2_5_0_5B_CONFIG) ModelRegistry.register("qwen2.5-1.5b", QwenModel, QWEN_2_5_1_5B_CONFIG) ModelRegistry.register("mistral-7b", MistralModel, MISTRAL_7B_CONFIG) # デフォルトモデルキー DEFAULT_MODEL_KEY = "gpt2"