File size: 2,211 Bytes
aefdcf0
 
 
 
 
 
185b05e
f6d689c
185b05e
 
aefdcf0
 
 
 
 
f6d689c
aefdcf0
 
 
 
f6d689c
aefdcf0
 
f6d689c
aefdcf0
 
f6d689c
aefdcf0
 
 
f6d689c
aefdcf0
 
 
 
f6d689c
aefdcf0
 
 
 
 
f6d689c
aefdcf0
f6d689c
aefdcf0
 
f6d689c
aefdcf0
 
f6d689c
aefdcf0
 
 
 
 
f6d689c
aefdcf0
 
 
f6d689c
aefdcf0
 
f6d689c
aefdcf0
 
f6d689c
aefdcf0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# tests/test_models/test_positional_encoding.py

"""
Tests for positional encoding.
"""

import matplotlib
import torch

matplotlib.use("Agg")  # use non-interactive backend for test environments
from src.models.positional_encoding import PositionalEncoding


class TestPositionalEncoding:
    """Unit tests for the PositionalEncoding module."""

    def test_output_shape(self):
        """The forward pass must preserve the input tensor's shape."""
        d_model = 512
        max_len = 5000
        encoder = PositionalEncoding(d_model, max_len, dropout=0.0)

        inputs = torch.randn(2, 100, d_model)
        result = encoder(inputs)

        assert result.shape == (2, 100, d_model)

    def test_different_sequence_lengths(self):
        """The encoder must accept a range of sequence lengths up to max_len."""
        dim = 256
        encoder = PositionalEncoding(d_model=dim, max_len=1000, dropout=0.0)

        for length in (10, 50, 100, 500):
            result = encoder(torch.randn(1, length, dim))
            assert result.shape == (1, length, dim)

    def test_dropout_changes_output(self):
        """Dropout randomizes outputs in train mode and is disabled in eval mode."""
        torch.manual_seed(42)
        encoder = PositionalEncoding(d_model=128, dropout=0.5)
        encoder.train()

        inputs = torch.randn(2, 10, 128)

        # Two training-mode passes should differ because dropout masks differ.
        first = encoder(inputs)
        second = encoder(inputs)
        assert not torch.allclose(first, second)

        # Eval mode disables dropout, so repeated calls agree exactly.
        encoder.eval()
        third = encoder(inputs)
        fourth = encoder(inputs)
        assert torch.allclose(third, fourth)

    def test_encoding_properties(self):
        """Sanity-check the raw sinusoidal table stored on the module."""
        encoder = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)

        # Drop the leading batch dimension to get a (position, d_model) table.
        table = encoder.pe[0]

        # Values come from sin/cos, so every entry lies in [-1, 1].
        assert (table >= -1).all() and (table <= 1).all()

        # Distinct positions must receive distinct encodings.
        assert not torch.allclose(table[0], table[1])
        assert not torch.allclose(table[0], table[50])