File size: 4,030 Bytes
d1033d4
 
 
 
 
 
 
 
 
 
 
 
 
f94169f
 
 
 
 
 
 
 
 
 
 
d1033d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f94169f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1033d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
モデルレジストリ

開放閉鎖原則(OCP)に準拠し、新モデル追加時に
既存コードの変更を不要にする
"""
from typing import Dict, List, Optional, Type

from .base import BaseLanguageModel, ModelConfig
from .gpt2 import GPT2Model, GPT2_SMALL_CONFIG, GPT2_MEDIUM_CONFIG
from .gpt_neo import GPTNeoModel, GPT_NEO_125M_CONFIG
from .opt import OPTModel, OPT_125M_CONFIG

# Phase 1: GPT-OSS and Fully Open Source Models
from .gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG
from .pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG
from .olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG
from .bloom import BLOOMModel, BLOOM_560M_CONFIG

# Phase 2: Latest Architecture Models
from .llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG
from .qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG
from .mistral import MistralModel, MISTRAL_7B_CONFIG


class ModelRegistry:
    """
    モデルレジストリ

    利用可能なモデルを管理し、キーに基づいて
    適切なモデルインスタンスを提供する
    """

    _registry: Dict[str, tuple[Type[BaseLanguageModel], ModelConfig]] = {}

    @classmethod
    def register(
        cls,
        key: str,
        model_class: Type[BaseLanguageModel],
        config: ModelConfig,
    ) -> None:
        """
        新しいモデルをレジストリに登録

        Args:
            key: モデルを識別するキー
            model_class: モデルクラス
            config: モデル設定
        """
        cls._registry[key] = (model_class, config)

    @classmethod
    def get(cls, key: str) -> BaseLanguageModel:
        """
        キーに対応するモデルインスタンスを取得

        Args:
            key: モデルを識別するキー

        Returns:
            モデルインスタンス

        Raises:
            KeyError: 指定されたキーが存在しない場合
        """
        if key not in cls._registry:
            available = ", ".join(cls._registry.keys())
            raise KeyError(f"Model '{key}' not found. Available: {available}")

        model_class, config = cls._registry[key]
        return model_class(config)

    @classmethod
    def list_models(cls) -> List[str]:
        """登録済みモデルのキー一覧を取得"""
        return list(cls._registry.keys())

    @classmethod
    def get_config(cls, key: str) -> Optional[ModelConfig]:
        """指定キーのモデル設定を取得"""
        if key not in cls._registry:
            return None
        return cls._registry[key][1]

    @classmethod
    def get_all_configs(cls) -> Dict[str, ModelConfig]:
        """すべてのモデル設定を取得"""
        return {key: config for key, (_, config) in cls._registry.items()}


# デフォルトモデルの登録
ModelRegistry.register("gpt2", GPT2Model, GPT2_SMALL_CONFIG)
ModelRegistry.register("gpt2-medium", GPT2Model, GPT2_MEDIUM_CONFIG)
ModelRegistry.register("gpt-neo-125m", GPTNeoModel, GPT_NEO_125M_CONFIG)
ModelRegistry.register("opt-125m", OPTModel, OPT_125M_CONFIG)

# Phase 1: GPT-OSS and Fully Open Source Models
ModelRegistry.register("gpt-oss-20b", GPTOSSModel, GPT_OSS_20B_CONFIG)
ModelRegistry.register("pythia-410m", PythiaModel, PYTHIA_410M_CONFIG)
ModelRegistry.register("pythia-1b", PythiaModel, PYTHIA_1B_CONFIG)
ModelRegistry.register("olmo-1b", OLMoModel, OLMO_1B_CONFIG)
ModelRegistry.register("olmo-7b", OLMoModel, OLMO_7B_CONFIG)
ModelRegistry.register("bloom-560m", BLOOMModel, BLOOM_560M_CONFIG)

# Phase 2: Latest Architecture Models
ModelRegistry.register("llama-3.2-1b", LlamaModel, LLAMA_3_2_1B_CONFIG)
ModelRegistry.register("llama-3.2-3b", LlamaModel, LLAMA_3_2_3B_CONFIG)
ModelRegistry.register("qwen2.5-0.5b", QwenModel, QWEN_2_5_0_5B_CONFIG)
ModelRegistry.register("qwen2.5-1.5b", QwenModel, QWEN_2_5_1_5B_CONFIG)
ModelRegistry.register("mistral-7b", MistralModel, MISTRAL_7B_CONFIG)

# デフォルトモデルキー
DEFAULT_MODEL_KEY = "gpt2"