File size: 2,489 Bytes
d1033d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
"""
ジェネレータ関連のテスト
"""
import pytest
import torch
from src.generators.debris_generator import DebrisGenerator, DebrisResult
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG
class TestDebrisResult:
"""DebrisResultのテスト"""
def test_result_attributes(self):
"""結果属性が正しく保持されることを確認"""
result = DebrisResult(
debris=["hello", "world"],
seed=12345,
noise=torch.randn(1, 32, 768),
logits=torch.randn(1, 32, 50257),
corrupted_logits=torch.randn(1, 32, 50257),
)
assert result.debris == ["hello", "world"]
assert result.seed == 12345
assert result.noise.shape == (1, 32, 768)
class TestDebrisGenerator:
"""DebrisGeneratorのテスト"""
@pytest.fixture
def generator(self):
"""ジェネレータインスタンスを提供"""
model = GPT2Model(GPT2_SMALL_CONFIG)
return DebrisGenerator(model)
def test_model_property(self, generator):
"""モデルプロパティが正しいことを確認"""
assert generator.model is not None
assert generator.model.config == GPT2_SMALL_CONFIG
@pytest.mark.slow
class TestDebrisGeneratorIntegration:
"""DebrisGeneratorの統合テスト"""
@pytest.fixture
def generator(self):
"""ロード済みジェネレータを提供"""
model = GPT2Model(GPT2_SMALL_CONFIG)
model.load()
return DebrisGenerator(model)
def test_generate_with_seed(self, generator):
"""シード指定で生成できることを確認"""
result = generator.generate(seed=42, seq_len=8)
assert isinstance(result, DebrisResult)
assert result.seed == 42
assert len(result.debris) == 8
def test_generate_reproducible(self, generator):
"""同じシードで同じ結果が得られることを確認"""
result1 = generator.generate(seed=12345, seq_len=8)
result2 = generator.generate(seed=12345, seq_len=8)
assert result1.debris == result2.debris
def test_generate_different_seeds(self, generator):
"""異なるシードで異なる結果が得られることを確認"""
result1 = generator.generate(seed=11111, seq_len=8)
result2 = generator.generate(seed=22222, seq_len=8)
# 完全一致する確率は極めて低い
assert result1.debris != result2.debris
|