""" ジェネレータ関連のテスト """ import pytest import torch from src.generators.debris_generator import DebrisGenerator, DebrisResult from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG class TestDebrisResult: """DebrisResultのテスト""" def test_result_attributes(self): """結果属性が正しく保持されることを確認""" result = DebrisResult( debris=["hello", "world"], seed=12345, noise=torch.randn(1, 32, 768), logits=torch.randn(1, 32, 50257), corrupted_logits=torch.randn(1, 32, 50257), ) assert result.debris == ["hello", "world"] assert result.seed == 12345 assert result.noise.shape == (1, 32, 768) class TestDebrisGenerator: """DebrisGeneratorのテスト""" @pytest.fixture def generator(self): """ジェネレータインスタンスを提供""" model = GPT2Model(GPT2_SMALL_CONFIG) return DebrisGenerator(model) def test_model_property(self, generator): """モデルプロパティが正しいことを確認""" assert generator.model is not None assert generator.model.config == GPT2_SMALL_CONFIG @pytest.mark.slow class TestDebrisGeneratorIntegration: """DebrisGeneratorの統合テスト""" @pytest.fixture def generator(self): """ロード済みジェネレータを提供""" model = GPT2Model(GPT2_SMALL_CONFIG) model.load() return DebrisGenerator(model) def test_generate_with_seed(self, generator): """シード指定で生成できることを確認""" result = generator.generate(seed=42, seq_len=8) assert isinstance(result, DebrisResult) assert result.seed == 42 assert len(result.debris) == 8 def test_generate_reproducible(self, generator): """同じシードで同じ結果が得られることを確認""" result1 = generator.generate(seed=12345, seq_len=8) result2 = generator.generate(seed=12345, seq_len=8) assert result1.debris == result2.debris def test_generate_different_seeds(self, generator): """異なるシードで異なる結果が得られることを確認""" result1 = generator.generate(seed=11111, seq_len=8) result2 = generator.generate(seed=22222, seq_len=8) # 完全一致する確率は極めて低い assert result1.debris != result2.debris