"""Tests for model classes."""
import pytest
import torch
from namer.models import NamerTransformer, PositionalEncoding
from namer.utils import VOCABULARY
class TestPositionalEncoding:
    """Unit tests for the PositionalEncoding module."""

    def test_shape(self) -> None:
        """The encoding layer preserves the input tensor's shape."""
        encoder = PositionalEncoding(d_model=128)
        inputs = torch.randn(2, 10, 128)  # (batch, seq, d_model) — assumed layout
        result = encoder(inputs)
        assert result.shape == (2, 10, 128)

    def test_adds_position(self) -> None:
        """Encoding an all-zero tensor produces non-zero output."""
        encoder = PositionalEncoding(d_model=64)
        zeros = torch.zeros(1, 5, 64)
        result = encoder(zeros)
        # If no positional signal were added, the output would still be all zeros.
        assert not torch.allclose(result, zeros)
class TestNamerTransformer:
    """Unit tests for the NamerTransformer model."""

    @pytest.fixture
    def model(self) -> NamerTransformer:
        """Build a small transformer so the tests run quickly."""
        return NamerTransformer(
            vocab_size=len(VOCABULARY),
            max_output_len=20,
            d_model=64,
            nhead=4,
            num_encoder_layers=2,
            dim_feedforward=128,
            dropout=0.0,
        )

    def test_forward_shape(self, model: NamerTransformer) -> None:
        """Forward returns logits shaped (batch, max_output_len, vocab_size)."""
        inputs = torch.randint(0, 10, (4, 10))  # batch of 4, sequence length 10
        out = model(inputs)
        assert out.shape == (4, model.max_output_len, model.vocab_size)

    def test_forward_with_padding(self, model: NamerTransformer) -> None:
        """Sequences containing the pad value 10 still yield full-shape logits."""
        inputs = torch.full((2, 10), 10)  # start fully padded
        inputs[:, :5] = torch.randint(0, 10, (2, 5))  # real digits in first half
        out = model(inputs)
        assert out.shape == (2, model.max_output_len, model.vocab_size)

    def test_forward_with_negative_padding(self, model: NamerTransformer) -> None:
        """Sequences containing -1 padding still yield full-shape logits."""
        inputs = torch.full((2, 10), -1)  # start fully padded with -1
        inputs[:, :5] = torch.randint(0, 10, (2, 5))  # real digits in first half
        out = model(inputs)
        assert out.shape == (2, model.max_output_len, model.vocab_size)

    def test_output_is_logits(self, model: NamerTransformer) -> None:
        """Output is raw logits — not all values lie inside [0, 1]."""
        out = model(torch.randint(0, 10, (1, 5)))
        # A softmax-ed output would have every element in [0, 1].
        in_unit_interval = (out >= 0) & (out <= 1)
        assert not torch.all(in_unit_interval)

    def test_gradient_flow(self, model: NamerTransformer) -> None:
        """Cross-entropy backprop populates gradients on every parameter."""
        out = model(torch.randint(0, 10, (2, 5)))
        labels = torch.randint(0, len(VOCABULARY), (2, model.max_output_len))
        loss = torch.nn.functional.cross_entropy(
            out.view(-1, model.vocab_size),
            labels.view(-1),
        )
        loss.backward()
        # Every trainable parameter should have received a gradient.
        assert all(p.grad is not None for p in model.parameters())
|