File size: 2,489 Bytes
d1033d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
ジェネレータ関連のテスト
"""
import pytest
import torch

from src.generators.debris_generator import DebrisGenerator, DebrisResult
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG


class TestDebrisResult:
    """DebrisResultのテスト"""

    def test_result_attributes(self):
        """結果属性が正しく保持されることを確認"""
        result = DebrisResult(
            debris=["hello", "world"],
            seed=12345,
            noise=torch.randn(1, 32, 768),
            logits=torch.randn(1, 32, 50257),
            corrupted_logits=torch.randn(1, 32, 50257),
        )

        assert result.debris == ["hello", "world"]
        assert result.seed == 12345
        assert result.noise.shape == (1, 32, 768)


class TestDebrisGenerator:
    """DebrisGeneratorのテスト"""

    @pytest.fixture
    def generator(self):
        """ジェネレータインスタンスを提供"""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        return DebrisGenerator(model)

    def test_model_property(self, generator):
        """モデルプロパティが正しいことを確認"""
        assert generator.model is not None
        assert generator.model.config == GPT2_SMALL_CONFIG


@pytest.mark.slow
class TestDebrisGeneratorIntegration:
    """DebrisGeneratorの統合テスト"""

    @pytest.fixture
    def generator(self):
        """ロード済みジェネレータを提供"""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        model.load()
        return DebrisGenerator(model)

    def test_generate_with_seed(self, generator):
        """シード指定で生成できることを確認"""
        result = generator.generate(seed=42, seq_len=8)

        assert isinstance(result, DebrisResult)
        assert result.seed == 42
        assert len(result.debris) == 8

    def test_generate_reproducible(self, generator):
        """同じシードで同じ結果が得られることを確認"""
        result1 = generator.generate(seed=12345, seq_len=8)
        result2 = generator.generate(seed=12345, seq_len=8)

        assert result1.debris == result2.debris

    def test_generate_different_seeds(self, generator):
        """異なるシードで異なる結果が得られることを確認"""
        result1 = generator.generate(seed=11111, seq_len=8)
        result2 = generator.generate(seed=22222, seq_len=8)

        # 完全一致する確率は極めて低い
        assert result1.debris != result2.debris