|
|
""" |
|
|
モデル関連のテスト |
|
|
""" |
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from src.models.base import ModelConfig, BaseLanguageModel |
|
|
from src.models.registry import ModelRegistry, DEFAULT_MODEL_KEY |
|
|
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG |
|
|
|
|
|
|
|
|
from src.models.gpt_oss import GPTOSSModel, GPT_OSS_20B_CONFIG |
|
|
from src.models.pythia import PythiaModel, PYTHIA_410M_CONFIG, PYTHIA_1B_CONFIG |
|
|
from src.models.olmo import OLMoModel, OLMO_1B_CONFIG, OLMO_7B_CONFIG |
|
|
from src.models.bloom import BLOOMModel, BLOOM_560M_CONFIG |
|
|
|
|
|
|
|
|
from src.models.llama import LlamaModel, LLAMA_3_2_1B_CONFIG, LLAMA_3_2_3B_CONFIG |
|
|
from src.models.qwen import QwenModel, QWEN_2_5_0_5B_CONFIG, QWEN_2_5_1_5B_CONFIG |
|
|
from src.models.mistral import MistralModel, MISTRAL_7B_CONFIG |
|
|
|
|
|
|
|
|
class TestModelConfig:
    """Tests for ModelConfig construction and immutability."""

    def test_config_is_immutable(self):
        """Assigning to a field of a ModelConfig must fail and leave the value unchanged."""
        config = ModelConfig(
            name="Test",
            model_id="test",
            embedding_dim=768,
            vocab_size=50000,
        )

        # The concrete exception type depends on how ModelConfig enforces
        # immutability (defined in src.models.base, not visible here), so any
        # exception is accepted. The follow-up assertion guards against an
        # implementation that raises only *after* overwriting the value.
        with pytest.raises(Exception):
            config.name = "Changed"

        assert config.name == "Test"

    def test_config_attributes(self):
        """Constructor arguments are exposed unchanged as attributes."""
        config = ModelConfig(
            name="Test Model",
            model_id="test-model",
            embedding_dim=1024,
            vocab_size=30000,
        )

        assert config.name == "Test Model"
        assert config.model_id == "test-model"
        assert config.embedding_dim == 1024
        assert config.vocab_size == 30000
|
|
|
|
|
|
|
|
class TestModelRegistry:
    """Tests for ModelRegistry lookup helpers."""

    def test_list_models(self):
        """list_models() is non-empty and contains DEFAULT_MODEL_KEY."""
        models = ModelRegistry.list_models()

        assert len(models) > 0
        assert DEFAULT_MODEL_KEY in models

    def test_get_model(self):
        """get() returns a BaseLanguageModel instance for the default key."""
        model = ModelRegistry.get(DEFAULT_MODEL_KEY)

        assert isinstance(model, BaseLanguageModel)

    def test_get_nonexistent_model(self):
        """get() raises KeyError for an unregistered key."""
        with pytest.raises(KeyError):
            ModelRegistry.get("nonexistent-model")

    def test_get_config(self):
        """get_config() returns a ModelConfig for the default key."""
        config = ModelRegistry.get_config(DEFAULT_MODEL_KEY)

        # isinstance() already fails for None, so no separate None check is needed.
        assert isinstance(config, ModelConfig)

    def test_get_all_configs(self):
        """get_all_configs() is non-empty and every value is a ModelConfig."""
        configs = ModelRegistry.get_all_configs()

        assert len(configs) > 0
        # Only the values are inspected, so iterate them directly (ruff PERF102).
        for config in configs.values():
            assert isinstance(config, ModelConfig)
|
|
|
|
|
|
|
|
class TestGPT2Model:
    """Unit tests for GPT2Model that never load weights."""

    @staticmethod
    def _fresh_model():
        """Return a new, unloaded GPT-2 small wrapper."""
        return GPT2Model(GPT2_SMALL_CONFIG)

    def test_config(self):
        """The wrapper exposes the config it was built with (768-dim embeddings)."""
        gpt2 = self._fresh_model()
        cfg = gpt2.config
        assert cfg == GPT2_SMALL_CONFIG
        assert cfg.embedding_dim == 768

    def test_is_loaded_initial(self):
        """A freshly constructed wrapper reports that it is not loaded."""
        assert not self._fresh_model().is_loaded

    def test_generate_noise(self):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 768) shape."""
        expected_shape = (2, 16, 768)
        noise = self._fresh_model().generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == expected_shape
|
|
|
|
|
|
|
|
@pytest.mark.slow
class TestGPT2ModelIntegration:
    """Integration tests for GPT2Model that call load() (marked slow)."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Provide a GPT-2 small model loaded once and shared by every test in this class.

        Loading weights dominates the runtime here, so the fixture is
        class-scoped instead of reloading the model for each test.
        NOTE(review): this assumes the tests below do not mutate model state;
        revisit the scope if a test starts doing so.
        """
        model = GPT2Model(GPT2_SMALL_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """load() leaves the model in the loaded state."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits shaped (1, seq_len, vocab_size).

        A batch dimension of 1 is expected because batch_size is not passed
        to generate_noise.
        """
        seq_len = 8
        noise = loaded_model.generate_noise(seq_len=seq_len)

        # Only the first return value is checked by this test.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        assert logits.shape[0] == 1
        assert logits.shape[1] == seq_len
        assert logits.shape[2] == loaded_model.config.vocab_size

    def test_decode_indices(self, loaded_model):
        """decode_indices returns one string per input index."""
        indices = [100, 200, 300]
        decoded = loaded_model.decode_indices(indices)

        assert len(decoded) == 3
        assert all(isinstance(s, str) for s in decoded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGPTOSSModel:
    """Unit tests for GPTOSSModel that never load weights."""

    # NOTE(review): the published openai/gpt-oss-20b config reports
    # hidden_size=2880 and vocab_size=201088, not the 4096 / 128000 asserted
    # here. Confirm which values GPT_OSS_20B_CONFIG is intended to carry.
    EXPECTED_EMBEDDING_DIM = 4096
    EXPECTED_VOCAB_SIZE = 128000

    def test_config(self):
        """The wrapper exposes GPT_OSS_20B_CONFIG with the expected sizes."""
        cfg = GPTOSSModel(GPT_OSS_20B_CONFIG).config
        assert cfg == GPT_OSS_20B_CONFIG
        assert cfg.embedding_dim == self.EXPECTED_EMBEDDING_DIM
        assert cfg.vocab_size == self.EXPECTED_VOCAB_SIZE

    def test_is_loaded_initial(self):
        """A freshly constructed wrapper reports that it is not loaded."""
        wrapper = GPTOSSModel(GPT_OSS_20B_CONFIG)
        assert wrapper.is_loaded is not True or not wrapper.is_loaded

    def test_generate_noise(self):
        """generate_noise(seq_len=16, batch_size=2) yields (2, 16, embedding_dim)."""
        wrapper = GPTOSSModel(GPT_OSS_20B_CONFIG)
        noise = wrapper.generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == (2, 16, self.EXPECTED_EMBEDDING_DIM)
|
|
|
|
|
|
|
|
class TestPythiaModel:
    """Unit tests for PythiaModel that never load weights."""

    @staticmethod
    def _assert_config(config, embedding_dim, vocab_size):
        """Wrap ``config`` in a PythiaModel and check it is kept with the given sizes."""
        model = PythiaModel(config)
        assert model.config == config
        assert model.config.embedding_dim == embedding_dim
        assert model.config.vocab_size == vocab_size

    def test_config_410m(self):
        """Pythia 410M: 1024-dim embeddings, 50304-entry vocabulary."""
        self._assert_config(PYTHIA_410M_CONFIG, 1024, 50304)

    def test_config_1b(self):
        """Pythia 1B: 2048-dim embeddings, 50304-entry vocabulary."""
        self._assert_config(PYTHIA_1B_CONFIG, 2048, 50304)

    def test_is_loaded_initial(self):
        """A freshly constructed wrapper reports that it is not loaded."""
        pythia = PythiaModel(PYTHIA_410M_CONFIG)
        assert not pythia.is_loaded

    def test_generate_noise(self):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 1024) shape."""
        pythia = PythiaModel(PYTHIA_410M_CONFIG)
        shape = pythia.generate_noise(seq_len=16, batch_size=2).shape
        assert shape == (2, 16, 1024)
|
|
|
|
|
|
|
|
class TestOLMoModel:
    """Unit tests for OLMoModel that never load weights."""

    def test_config_1b(self):
        """OLMo 1B keeps its config: 2048-dim embeddings, 50304-entry vocabulary."""
        olmo = OLMoModel(OLMO_1B_CONFIG)
        assert olmo.config == OLMO_1B_CONFIG
        sizes = (olmo.config.embedding_dim, olmo.config.vocab_size)
        assert sizes == (2048, 50304)

    def test_config_7b(self):
        """OLMo 7B keeps its config: 4096-dim embeddings, 50304-entry vocabulary."""
        olmo = OLMoModel(OLMO_7B_CONFIG)
        assert olmo.config == OLMO_7B_CONFIG
        sizes = (olmo.config.embedding_dim, olmo.config.vocab_size)
        assert sizes == (4096, 50304)

    def test_is_loaded_initial(self):
        """A freshly constructed wrapper reports that it is not loaded."""
        assert not OLMoModel(OLMO_1B_CONFIG).is_loaded

    def test_generate_noise(self):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 2048) shape."""
        batch, length, width = 2, 16, 2048
        olmo = OLMoModel(OLMO_1B_CONFIG)
        noise = olmo.generate_noise(seq_len=length, batch_size=batch)
        assert noise.shape == (batch, length, width)
|
|
|
|
|
|
|
|
class TestBLOOMModel:
    """Unit tests for BLOOMModel that never load weights."""

    @pytest.fixture
    def bloom(self):
        """Provide a fresh, unloaded BLOOM 560M wrapper for each test."""
        return BLOOMModel(BLOOM_560M_CONFIG)

    def test_config(self, bloom):
        """BLOOM 560M keeps its config: 1024-dim embeddings, 250880-entry vocabulary."""
        cfg = bloom.config
        assert cfg == BLOOM_560M_CONFIG
        assert cfg.embedding_dim == 1024
        assert cfg.vocab_size == 250880

    def test_is_loaded_initial(self, bloom):
        """A freshly constructed wrapper reports that it is not loaded."""
        assert not bloom.is_loaded

    def test_generate_noise(self, bloom):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 1024) shape."""
        noise = bloom.generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == (2, 16, 1024)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestLlamaModel:
    """Unit tests for LlamaModel that never load weights."""

    def test_config_1b(self):
        """Llama 3.2 1B keeps its config: 2048-dim embeddings, 128256-entry vocabulary."""
        expected_dim, expected_vocab = 2048, 128256
        llama = LlamaModel(LLAMA_3_2_1B_CONFIG)
        cfg = llama.config
        assert cfg == LLAMA_3_2_1B_CONFIG
        assert cfg.embedding_dim == expected_dim
        assert cfg.vocab_size == expected_vocab

    def test_config_3b(self):
        """Llama 3.2 3B keeps its config: 3072-dim embeddings, 128256-entry vocabulary."""
        expected_dim, expected_vocab = 3072, 128256
        llama = LlamaModel(LLAMA_3_2_3B_CONFIG)
        cfg = llama.config
        assert cfg == LLAMA_3_2_3B_CONFIG
        assert cfg.embedding_dim == expected_dim
        assert cfg.vocab_size == expected_vocab

    def test_is_loaded_initial(self):
        """A freshly constructed wrapper reports that it is not loaded."""
        llama = LlamaModel(LLAMA_3_2_1B_CONFIG)
        loaded = llama.is_loaded
        assert not loaded

    def test_generate_noise(self):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 2048) shape."""
        llama = LlamaModel(LLAMA_3_2_1B_CONFIG)
        assert llama.generate_noise(seq_len=16, batch_size=2).shape == (2, 16, 2048)
|
|
|
|
|
|
|
|
class TestQwenModel:
    """Unit tests for QwenModel that never load weights."""

    @staticmethod
    def _check_sizes(config, *, dim, vocab):
        """Wrap ``config`` in a QwenModel and verify identity and sizes."""
        qwen = QwenModel(config)
        assert qwen.config == config
        assert qwen.config.embedding_dim == dim
        assert qwen.config.vocab_size == vocab

    def test_config_0_5b(self):
        """Qwen2.5 0.5B: 896-dim embeddings, 151936-entry vocabulary."""
        self._check_sizes(QWEN_2_5_0_5B_CONFIG, dim=896, vocab=151936)

    def test_config_1_5b(self):
        """Qwen2.5 1.5B: 1536-dim embeddings, 151936-entry vocabulary."""
        self._check_sizes(QWEN_2_5_1_5B_CONFIG, dim=1536, vocab=151936)

    def test_is_loaded_initial(self):
        """A freshly constructed wrapper reports that it is not loaded."""
        assert not QwenModel(QWEN_2_5_0_5B_CONFIG).is_loaded

    def test_generate_noise(self):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 896) shape."""
        noise = QwenModel(QWEN_2_5_0_5B_CONFIG).generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == (2, 16, 896)
|
|
|
|
|
|
|
|
class TestMistralModel:
    """Unit tests for MistralModel that never load weights."""

    @pytest.fixture
    def mistral(self):
        """Provide a fresh, unloaded Mistral 7B wrapper for each test."""
        return MistralModel(MISTRAL_7B_CONFIG)

    def test_config(self, mistral):
        """Mistral 7B keeps its config: 4096-dim embeddings, 32768-entry vocabulary."""
        assert mistral.config == MISTRAL_7B_CONFIG
        assert mistral.config.embedding_dim == 4096
        assert mistral.config.vocab_size == 32768

    def test_is_loaded_initial(self, mistral):
        """A freshly constructed wrapper reports that it is not loaded."""
        assert not mistral.is_loaded

    def test_generate_noise(self, mistral):
        """generate_noise(seq_len=16, batch_size=2) yields a (2, 16, 4096) shape."""
        expected = (2, 16, 4096)
        assert mistral.generate_noise(seq_len=16, batch_size=2).shape == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestModelRegistryNewModels:
    """Registry tests for the newly added model families."""

    # Single source of truth for both parametrized tests below. Previously the
    # same list was duplicated verbatim, so adding a model to one test but not
    # the other would silently leave a gap.
    NEW_MODEL_KEYS = (
        "gpt-oss-20b",
        "pythia-410m",
        "pythia-1b",
        "olmo-1b",
        "olmo-7b",
        "bloom-560m",
        "llama-3.2-1b",
        "llama-3.2-3b",
        "qwen2.5-0.5b",
        "qwen2.5-1.5b",
        "mistral-7b",
    )

    @pytest.mark.parametrize("model_key", NEW_MODEL_KEYS)
    def test_model_registered(self, model_key):
        """Each new model key appears in ModelRegistry.list_models()."""
        models = ModelRegistry.list_models()

        assert model_key in models

    @pytest.mark.parametrize("model_key", NEW_MODEL_KEYS)
    def test_model_instance_creation(self, model_key):
        """ModelRegistry.get() builds an unloaded BaseLanguageModel for each new key."""
        model = ModelRegistry.get(model_key)

        assert isinstance(model, BaseLanguageModel)
        assert not model.is_loaded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.slow
class TestPythiaModelIntegration:
    """Integration tests for PythiaModel, using the 410M model as a representative (marked slow)."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Provide a Pythia 410M model loaded once and shared by every test in this class.

        Loading weights dominates the runtime here, so the fixture is
        class-scoped instead of reloading the model for each test.
        NOTE(review): this assumes the tests below do not mutate model state;
        revisit the scope if a test starts doing so.
        """
        model = PythiaModel(PYTHIA_410M_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """load() leaves the model in the loaded state."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits shaped (1, seq_len, vocab_size).

        A batch dimension of 1 is expected because batch_size is not passed
        to generate_noise.
        """
        seq_len = 8
        noise = loaded_model.generate_noise(seq_len=seq_len)

        # Only the first return value is checked by this test.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        assert logits.shape[0] == 1
        assert logits.shape[1] == seq_len
        assert logits.shape[2] == loaded_model.config.vocab_size

    def test_decode_indices(self, loaded_model):
        """decode_indices returns one string per input index."""
        indices = [100, 200, 300]
        decoded = loaded_model.decode_indices(indices)

        assert len(decoded) == 3
        assert all(isinstance(s, str) for s in decoded)
|
|
|
|
|
|
|
|
@pytest.mark.slow
class TestBLOOMModelIntegration:
    """Integration tests for BLOOMModel (marked slow)."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Provide a BLOOM 560M model loaded once and shared by every test in this class.

        Loading weights dominates the runtime here, so the fixture is
        class-scoped instead of reloading the model for each test.
        NOTE(review): this assumes the tests below do not mutate model state;
        revisit the scope if a test starts doing so.
        """
        model = BLOOMModel(BLOOM_560M_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """load() leaves the model in the loaded state."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits shaped (1, seq_len, vocab_size).

        A batch dimension of 1 is expected because batch_size is not passed
        to generate_noise.
        """
        seq_len = 8
        noise = loaded_model.generate_noise(seq_len=seq_len)

        # Only the first return value is checked by this test.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        assert logits.shape[0] == 1
        assert logits.shape[1] == seq_len
        assert logits.shape[2] == loaded_model.config.vocab_size
|
|
|
|
|
|
|
|
@pytest.mark.slow
class TestQwenModelIntegration:
    """Integration tests for QwenModel, using the 0.5B model as a representative (marked slow)."""

    @pytest.fixture(scope="class")
    def loaded_model(self):
        """Provide a Qwen2.5 0.5B model loaded once and shared by every test in this class.

        Loading weights dominates the runtime here, so the fixture is
        class-scoped instead of reloading the model for each test.
        NOTE(review): this assumes the tests below do not mutate model state;
        revisit the scope if a test starts doing so.
        """
        model = QwenModel(QWEN_2_5_0_5B_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """load() leaves the model in the loaded state."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """forward_with_noise returns logits shaped (1, seq_len, vocab_size).

        A batch dimension of 1 is expected because batch_size is not passed
        to generate_noise.
        """
        seq_len = 8
        noise = loaded_model.generate_noise(seq_len=seq_len)

        # Only the first return value is checked by this test.
        logits, _corrupted_logits = loaded_model.forward_with_noise(noise)

        assert logits.shape[0] == 1
        assert logits.shape[1] == seq_len
        assert logits.shape[2] == loaded_model.config.vocab_size
|
|
|