|
|
""" |
|
|
ジェネレータ関連のテスト |
|
|
""" |
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from src.generators.debris_generator import DebrisGenerator, DebrisResult |
|
|
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG |
|
|
|
|
|
|
|
|
class TestDebrisResult: |
|
|
"""DebrisResultのテスト""" |
|
|
|
|
|
def test_result_attributes(self): |
|
|
"""結果属性が正しく保持されることを確認""" |
|
|
result = DebrisResult( |
|
|
debris=["hello", "world"], |
|
|
seed=12345, |
|
|
noise=torch.randn(1, 32, 768), |
|
|
logits=torch.randn(1, 32, 50257), |
|
|
corrupted_logits=torch.randn(1, 32, 50257), |
|
|
) |
|
|
|
|
|
assert result.debris == ["hello", "world"] |
|
|
assert result.seed == 12345 |
|
|
assert result.noise.shape == (1, 32, 768) |
|
|
|
|
|
|
|
|
class TestDebrisGenerator: |
|
|
"""DebrisGeneratorのテスト""" |
|
|
|
|
|
@pytest.fixture |
|
|
def generator(self): |
|
|
"""ジェネレータインスタンスを提供""" |
|
|
model = GPT2Model(GPT2_SMALL_CONFIG) |
|
|
return DebrisGenerator(model) |
|
|
|
|
|
def test_model_property(self, generator): |
|
|
"""モデルプロパティが正しいことを確認""" |
|
|
assert generator.model is not None |
|
|
assert generator.model.config == GPT2_SMALL_CONFIG |
|
|
|
|
|
|
|
|
@pytest.mark.slow |
|
|
class TestDebrisGeneratorIntegration: |
|
|
"""DebrisGeneratorの統合テスト""" |
|
|
|
|
|
@pytest.fixture |
|
|
def generator(self): |
|
|
"""ロード済みジェネレータを提供""" |
|
|
model = GPT2Model(GPT2_SMALL_CONFIG) |
|
|
model.load() |
|
|
return DebrisGenerator(model) |
|
|
|
|
|
def test_generate_with_seed(self, generator): |
|
|
"""シード指定で生成できることを確認""" |
|
|
result = generator.generate(seed=42, seq_len=8) |
|
|
|
|
|
assert isinstance(result, DebrisResult) |
|
|
assert result.seed == 42 |
|
|
assert len(result.debris) == 8 |
|
|
|
|
|
def test_generate_reproducible(self, generator): |
|
|
"""同じシードで同じ結果が得られることを確認""" |
|
|
result1 = generator.generate(seed=12345, seq_len=8) |
|
|
result2 = generator.generate(seed=12345, seq_len=8) |
|
|
|
|
|
assert result1.debris == result2.debris |
|
|
|
|
|
def test_generate_different_seeds(self, generator): |
|
|
"""異なるシードで異なる結果が得られることを確認""" |
|
|
result1 = generator.generate(seed=11111, seq_len=8) |
|
|
result2 = generator.generate(seed=22222, seq_len=8) |
|
|
|
|
|
|
|
|
assert result1.debris != result2.debris |
|
|
|