will / src /models /registry.py
matt1847's picture
機能追加: モデルラインナップ拡張とGradio UI移行
f94169f
"""
モデルレジストリ
開放閉鎖原則(OCP)に準拠し、新モデル追加時に
既存コードの変更を不要にする
"""
from typing import Dict, List, Optional, Type
from .base import BaseLanguageModel, ModelConfig
from .gpt2 import GPT2Model, GPT2_SMALL_CONFIG, GPT2_MEDIUM_CONFIG
from .gpt_neo import GPTNeoModel, GPT_NEO_125M_CONFIG
from .opt import OPTModel, OPT_125M_CONFIG
# Phase 1: GPT-OSS and Fully Open Source Models
from .gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG
from .pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG
from .olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG
from .bloom import BLOOMModel, BLOOM_560M_CONFIG
# Phase 2: Latest Architecture Models
from .llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG
from .qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG
from .mistral import MistralModel, MISTRAL_7B_CONFIG
class ModelRegistry:
"""
モデルレジストリ
利用可能なモデルを管理し、キーに基づいて
適切なモデルインスタンスを提供する
"""
_registry: Dict[str, tuple[Type[BaseLanguageModel], ModelConfig]] = {}
@classmethod
def register(
cls,
key: str,
model_class: Type[BaseLanguageModel],
config: ModelConfig,
) -> None:
"""
新しいモデルをレジストリに登録
Args:
key: モデルを識別するキー
model_class: モデルクラス
config: モデル設定
"""
cls._registry[key] = (model_class, config)
@classmethod
def get(cls, key: str) -> BaseLanguageModel:
"""
キーに対応するモデルインスタンスを取得
Args:
key: モデルを識別するキー
Returns:
モデルインスタンス
Raises:
KeyError: 指定されたキーが存在しない場合
"""
if key not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"Model '{key}' not found. Available: {available}")
model_class, config = cls._registry[key]
return model_class(config)
@classmethod
def list_models(cls) -> List[str]:
"""登録済みモデルのキー一覧を取得"""
return list(cls._registry.keys())
@classmethod
def get_config(cls, key: str) -> Optional[ModelConfig]:
"""指定キーのモデル設定を取得"""
if key not in cls._registry:
return None
return cls._registry[key][1]
@classmethod
def get_all_configs(cls) -> Dict[str, ModelConfig]:
"""すべてのモデル設定を取得"""
return {key: config for key, (_, config) in cls._registry.items()}
# デフォルトモデルの登録
ModelRegistry.register("gpt2", GPT2Model, GPT2_SMALL_CONFIG)
ModelRegistry.register("gpt2-medium", GPT2Model, GPT2_MEDIUM_CONFIG)
ModelRegistry.register("gpt-neo-125m", GPTNeoModel, GPT_NEO_125M_CONFIG)
ModelRegistry.register("opt-125m", OPTModel, OPT_125M_CONFIG)
# Phase 1: GPT-OSS and Fully Open Source Models
ModelRegistry.register("gpt-oss-20b", GPTOSSModel, GPT_OSS_20B_CONFIG)
ModelRegistry.register("pythia-410m", PythiaModel, PYTHIA_410M_CONFIG)
ModelRegistry.register("pythia-1b", PythiaModel, PYTHIA_1B_CONFIG)
ModelRegistry.register("olmo-1b", OLMoModel, OLMO_1B_CONFIG)
ModelRegistry.register("olmo-7b", OLMoModel, OLMO_7B_CONFIG)
ModelRegistry.register("bloom-560m", BLOOMModel, BLOOM_560M_CONFIG)
# Phase 2: Latest Architecture Models
ModelRegistry.register("llama-3.2-1b", LlamaModel, LLAMA_3_2_1B_CONFIG)
ModelRegistry.register("llama-3.2-3b", LlamaModel, LLAMA_3_2_3B_CONFIG)
ModelRegistry.register("qwen2.5-0.5b", QwenModel, QWEN_2_5_0_5B_CONFIG)
ModelRegistry.register("qwen2.5-1.5b", QwenModel, QWEN_2_5_1_5B_CONFIG)
ModelRegistry.register("mistral-7b", MistralModel, MISTRAL_7B_CONFIG)
# デフォルトモデルキー
DEFAULT_MODEL_KEY = "gpt2"