""" モデル関連のテスト """ import pytest import torch from src.models.base import ModelConfig, BaseLanguageModel from src.models.registry import ModelRegistry, DEFAULT_MODEL_KEY from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG # Phase 1: GPT-OSS and Fully Open Source Models from src.models.gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG from src.models.pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG from src.models.olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG from src.models.bloom import BLOOMModel, BLOOM_560M_CONFIG # Phase 2: Latest Architecture Models from src.models.llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG from src.models.qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG from src.models.mistral import MistralModel, MISTRAL_7B_CONFIG class TestModelConfig: """ModelConfigのテスト""" def test_config_is_immutable(self): """設定がイミュータブルであることを確認""" config = ModelConfig( name="Test", model_id="test", embedding_dim=768, vocab_size=50000, ) with pytest.raises(Exception): config.name = "Changed" def test_config_attributes(self): """設定属性が正しく保持されることを確認""" config = ModelConfig( name="Test Model", model_id="test-model", embedding_dim=1024, vocab_size=30000, ) assert config.name == "Test Model" assert config.model_id == "test-model" assert config.embedding_dim == 1024 assert config.vocab_size == 30000 class TestModelRegistry: """ModelRegistryのテスト""" def test_list_models(self): """登録済みモデル一覧が取得できることを確認""" models = ModelRegistry.list_models() assert len(models) > 0 assert DEFAULT_MODEL_KEY in models def test_get_model(self): """モデルインスタンスが取得できることを確認""" model = ModelRegistry.get(DEFAULT_MODEL_KEY) assert isinstance(model, BaseLanguageModel) def test_get_nonexistent_model(self): """存在しないモデルでKeyErrorが発生することを確認""" with pytest.raises(KeyError): ModelRegistry.get("nonexistent-model") def test_get_config(self): """モデル設定が取得できることを確認""" config = ModelRegistry.get_config(DEFAULT_MODEL_KEY) assert config is not None assert isinstance(config, ModelConfig) def test_get_all_configs(self): """すべてのモデル設定が取得できることを確認""" configs = ModelRegistry.get_all_configs() assert len(configs) > 0 for key, config in configs.items(): assert isinstance(config, ModelConfig) class TestGPT2Model: """GPT2Modelのテスト""" def test_config(self): """設定が正しいことを確認""" model = GPT2Model(GPT2_SMALL_CONFIG) assert model.config == GPT2_SMALL_CONFIG assert model.config.embedding_dim == 768 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = GPT2Model(GPT2_SMALL_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = GPT2Model(GPT2_SMALL_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 768) @pytest.mark.slow class TestGPT2ModelIntegration: """GPT2Modelの統合テスト(モデルロードが必要)""" @pytest.fixture def loaded_model(self): """ロード済みモデルを提供""" model = GPT2Model(GPT2_SMALL_CONFIG) model.load() return model def test_load(self, loaded_model): """モデルがロードできることを確認""" assert loaded_model.is_loaded def test_forward_with_noise(self, loaded_model): """順伝播が正しい形状を返すことを確認""" noise = loaded_model.generate_noise(seq_len=8) logits, corrupted_logits = loaded_model.forward_with_noise(noise) assert logits.shape[0] == 1 assert logits.shape[1] == 8 assert logits.shape[2] == loaded_model.config.vocab_size def test_decode_indices(self, loaded_model): """デコードが文字列リストを返すことを確認""" indices = [100, 200, 300] decoded = loaded_model.decode_indices(indices) assert len(decoded) == 3 assert all(isinstance(s, str) for s in decoded) # ============================================================================= # Phase 1: GPT-OSS and Fully Open Source Models # ============================================================================= class TestGPTOSSModel: """GPTOSSModelのテスト""" def test_config(self): """設定が正しいことを確認""" model = GPTOSSModel(GPT_OSS_20B_CONFIG) assert model.config == GPT_OSS_20B_CONFIG assert model.config.embedding_dim == 4096 assert model.config.vocab_size == 128000 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = GPTOSSModel(GPT_OSS_20B_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = GPTOSSModel(GPT_OSS_20B_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 4096) class TestPythiaModel: """PythiaModelのテスト""" def test_config_410m(self): """Pythia 410M設定が正しいことを確認""" model = PythiaModel(PYTHIA_410M_CONFIG) assert model.config == PYTHIA_410M_CONFIG assert model.config.embedding_dim == 1024 assert model.config.vocab_size == 50304 def test_config_1b(self): """Pythia 1B設定が正しいことを確認""" model = PythiaModel(PYTHIA_1B_CONFIG) assert model.config == PYTHIA_1B_CONFIG assert model.config.embedding_dim == 2048 assert model.config.vocab_size == 50304 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = PythiaModel(PYTHIA_410M_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = PythiaModel(PYTHIA_410M_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 1024) class TestOLMoModel: """OLMoModelのテスト""" def test_config_1b(self): """OLMo 1B設定が正しいことを確認""" model = OLMoModel(OLMO_1B_CONFIG) assert model.config == OLMO_1B_CONFIG assert model.config.embedding_dim == 2048 assert model.config.vocab_size == 50304 def test_config_7b(self): """OLMo 7B設定が正しいことを確認""" model = OLMoModel(OLMO_7B_CONFIG) assert model.config == OLMO_7B_CONFIG assert model.config.embedding_dim == 4096 assert model.config.vocab_size == 50304 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = OLMoModel(OLMO_1B_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = OLMoModel(OLMO_1B_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 2048) class TestBLOOMModel: """BLOOMModelのテスト""" def test_config(self): """BLOOM 560M設定が正しいことを確認""" model = BLOOMModel(BLOOM_560M_CONFIG) assert model.config == BLOOM_560M_CONFIG assert model.config.embedding_dim == 1024 assert model.config.vocab_size == 250880 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = BLOOMModel(BLOOM_560M_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = BLOOMModel(BLOOM_560M_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 1024) # ============================================================================= # Phase 2: Latest Architecture Models # ============================================================================= class TestLlamaModel: """LlamaModelのテスト""" def test_config_1b(self): """Llama 3.2 1B設定が正しいことを確認""" model = LlamaModel(LLAMA_3_2_1B_CONFIG) assert model.config == LLAMA_3_2_1B_CONFIG assert model.config.embedding_dim == 2048 assert model.config.vocab_size == 128256 def test_config_3b(self): """Llama 3.2 3B設定が正しいことを確認""" model = LlamaModel(LLAMA_3_2_3B_CONFIG) assert model.config == LLAMA_3_2_3B_CONFIG assert model.config.embedding_dim == 3072 assert model.config.vocab_size == 128256 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = LlamaModel(LLAMA_3_2_1B_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = LlamaModel(LLAMA_3_2_1B_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 2048) class TestQwenModel: """QwenModelのテスト""" def test_config_0_5b(self): """Qwen2.5 0.5B設定が正しいことを確認""" model = QwenModel(QWEN_2_5_0_5B_CONFIG) assert model.config == QWEN_2_5_0_5B_CONFIG assert model.config.embedding_dim == 896 assert model.config.vocab_size == 151936 def test_config_1_5b(self): """Qwen2.5 1.5B設定が正しいことを確認""" model = QwenModel(QWEN_2_5_1_5B_CONFIG) assert model.config == QWEN_2_5_1_5B_CONFIG assert model.config.embedding_dim == 1536 assert model.config.vocab_size == 151936 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = QwenModel(QWEN_2_5_0_5B_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = QwenModel(QWEN_2_5_0_5B_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 896) class TestMistralModel: """MistralModelのテスト""" def test_config(self): """Mistral 7B設定が正しいことを確認""" model = MistralModel(MISTRAL_7B_CONFIG) assert model.config == MISTRAL_7B_CONFIG assert model.config.embedding_dim == 4096 assert model.config.vocab_size == 32768 def test_is_loaded_initial(self): """初期状態ではロードされていないことを確認""" model = MistralModel(MISTRAL_7B_CONFIG) assert not model.is_loaded def test_generate_noise(self): """ノイズ生成が正しい形状であることを確認""" model = MistralModel(MISTRAL_7B_CONFIG) noise = model.generate_noise(seq_len=16, batch_size=2) assert noise.shape == (2, 16, 4096) # ============================================================================= # Registry Tests for New Models # ============================================================================= class TestModelRegistryNewModels: """新規追加モデルのレジストリテスト""" @pytest.mark.parametrize("model_key", [ "gpt-oss-20b", "pythia-410m", "pythia-1b", "olmo-1b", "olmo-7b", "bloom-560m", "llama-3.2-1b", "llama-3.2-3b", "qwen2.5-0.5b", "qwen2.5-1.5b", "mistral-7b", ]) def test_model_registered(self, model_key): """新モデルがレジストリに登録されていることを確認""" models = ModelRegistry.list_models() assert model_key in models @pytest.mark.parametrize("model_key", [ "gpt-oss-20b", "pythia-410m", "pythia-1b", "olmo-1b", "olmo-7b", "bloom-560m", "llama-3.2-1b", "llama-3.2-3b", "qwen2.5-0.5b", "qwen2.5-1.5b", "mistral-7b", ]) def test_model_instance_creation(self, model_key): """新モデルのインスタンスが作成できることを確認""" model = ModelRegistry.get(model_key) assert isinstance(model, BaseLanguageModel) assert not model.is_loaded # ============================================================================= # Integration Tests (require model download) # ============================================================================= @pytest.mark.slow class TestPythiaModelIntegration: """Pythiaモデルの統合テスト(小さいモデルで代表テスト)""" @pytest.fixture def loaded_model(self): """ロード済みモデルを提供""" model = PythiaModel(PYTHIA_410M_CONFIG) model.load() return model def test_load(self, loaded_model): """モデルがロードできることを確認""" assert loaded_model.is_loaded def test_forward_with_noise(self, loaded_model): """順伝播が正しい形状を返すことを確認""" noise = loaded_model.generate_noise(seq_len=8) logits, corrupted_logits = loaded_model.forward_with_noise(noise) assert logits.shape[0] == 1 assert logits.shape[1] == 8 assert logits.shape[2] == loaded_model.config.vocab_size def test_decode_indices(self, loaded_model): """デコードが文字列リストを返すことを確認""" indices = [100, 200, 300] decoded = loaded_model.decode_indices(indices) assert len(decoded) == 3 assert all(isinstance(s, str) for s in decoded) @pytest.mark.slow class TestBLOOMModelIntegration: """BLOOMモデルの統合テスト""" @pytest.fixture def loaded_model(self): """ロード済みモデルを提供""" model = BLOOMModel(BLOOM_560M_CONFIG) model.load() return model def test_load(self, loaded_model): """モデルがロードできることを確認""" assert loaded_model.is_loaded def test_forward_with_noise(self, loaded_model): """順伝播が正しい形状を返すことを確認""" noise = loaded_model.generate_noise(seq_len=8) logits, corrupted_logits = loaded_model.forward_with_noise(noise) assert logits.shape[0] == 1 assert logits.shape[1] == 8 assert logits.shape[2] == loaded_model.config.vocab_size @pytest.mark.slow class TestQwenModelIntegration: """Qwenモデルの統合テスト(小さいモデルで代表テスト)""" @pytest.fixture def loaded_model(self): """ロード済みモデルを提供""" model = QwenModel(QWEN_2_5_0_5B_CONFIG) model.load() return model def test_load(self, loaded_model): """モデルがロードできることを確認""" assert loaded_model.is_loaded def test_forward_with_noise(self, loaded_model): """順伝播が正しい形状を返すことを確認""" noise = loaded_model.generate_noise(seq_len=8) logits, corrupted_logits = loaded_model.forward_with_noise(noise) assert logits.shape[0] == 1 assert logits.shape[1] == 8 assert logits.shape[2] == loaded_model.config.vocab_size