will / tests /test_generators.py
matt1847's picture
リファクタ: srcディレクトリ構造への移行とDocker対応
d1033d4
"""
ジェネレータ関連のテスト
"""
import pytest
import torch
from src.generators.debris_generator import DebrisGenerator, DebrisResult
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG
class TestDebrisResult:
"""DebrisResultのテスト"""
def test_result_attributes(self):
"""結果属性が正しく保持されることを確認"""
result = DebrisResult(
debris=["hello", "world"],
seed=12345,
noise=torch.randn(1, 32, 768),
logits=torch.randn(1, 32, 50257),
corrupted_logits=torch.randn(1, 32, 50257),
)
assert result.debris == ["hello", "world"]
assert result.seed == 12345
assert result.noise.shape == (1, 32, 768)
class TestDebrisGenerator:
"""DebrisGeneratorのテスト"""
@pytest.fixture
def generator(self):
"""ジェネレータインスタンスを提供"""
model = GPT2Model(GPT2_SMALL_CONFIG)
return DebrisGenerator(model)
def test_model_property(self, generator):
"""モデルプロパティが正しいことを確認"""
assert generator.model is not None
assert generator.model.config == GPT2_SMALL_CONFIG
@pytest.mark.slow
class TestDebrisGeneratorIntegration:
"""DebrisGeneratorの統合テスト"""
@pytest.fixture
def generator(self):
"""ロード済みジェネレータを提供"""
model = GPT2Model(GPT2_SMALL_CONFIG)
model.load()
return DebrisGenerator(model)
def test_generate_with_seed(self, generator):
"""シード指定で生成できることを確認"""
result = generator.generate(seed=42, seq_len=8)
assert isinstance(result, DebrisResult)
assert result.seed == 42
assert len(result.debris) == 8
def test_generate_reproducible(self, generator):
"""同じシードで同じ結果が得られることを確認"""
result1 = generator.generate(seed=12345, seq_len=8)
result2 = generator.generate(seed=12345, seq_len=8)
assert result1.debris == result2.debris
def test_generate_different_seeds(self, generator):
"""異なるシードで異なる結果が得られることを確認"""
result1 = generator.generate(seed=11111, seq_len=8)
result2 = generator.generate(seed=22222, seq_len=8)
# 完全一致する確率は極めて低い
assert result1.debris != result2.debris