File size: 2,253 Bytes
"""Tests for model architectures."""
import pytest
import torch
from models.simple_classifier import SimpleClassifier
from models.cnn_classifier import CNNClassifier
def test_simple_classifier_title_only():
    """Run SimpleClassifier on a title batch alone and validate the output.

    The model is built with ``use_snippet=False`` and called with a single
    random integer tensor of shape ``(4, 20)``. The result must have shape
    ``(4, 50)`` (batch size by ``output_dim``) and contain no NaN values.
    """
    n_rows, n_tokens = 4, 20

    classifier = SimpleClassifier(
        vocab_size=1000,
        embedding_dim=100,
        output_dim=50,
        use_snippet=False,
    )

    # Random integer ids in [0, 1000), matching the vocab_size passed above.
    title_ids = torch.randint(0, 1000, (n_rows, n_tokens))
    result = classifier(title_ids)

    assert tuple(result.shape) == (n_rows, 50)

    has_nan = torch.isnan(result).any()
    assert not has_nan
def test_simple_classifier_with_snippet():
    """Run SimpleClassifier on title plus snippet and validate the output.

    The model is built with ``use_snippet=True`` and called with two random
    integer tensors: a ``(4, 20)`` title batch and a ``(4, 50)`` snippet
    batch. The result must have shape ``(4, 50)`` (batch size by
    ``output_dim``) and contain no NaN values.
    """
    n_rows = 4
    title_shape = (n_rows, 20)
    snippet_shape = (n_rows, 50)

    classifier = SimpleClassifier(
        vocab_size=1000,
        embedding_dim=100,
        output_dim=50,
        use_snippet=True,
    )

    # Random integer ids in [0, 1000), matching the vocab_size passed above.
    title_ids = torch.randint(0, 1000, title_shape)
    snippet_ids = torch.randint(0, 1000, snippet_shape)
    result = classifier(title_ids, snippet_ids)

    assert tuple(result.shape) == (n_rows, 50)

    has_nan = torch.isnan(result).any()
    assert not has_nan
def test_cnn_classifier():
    """Run CNNClassifier with explicit conv settings and validate the output.

    The model is configured with two conv channel sizes (64, 128) and two
    kernel sizes (3, 3), then called with a ``(4, 20)`` title batch and a
    ``(4, 50)`` snippet batch of random integers. The result must have shape
    ``(4, 50)`` (batch size by ``output_dim``) and contain no NaN values.
    """
    config = {
        "vocab_size": 1000,
        "embedding_dim": 100,
        "output_dim": 50,
        "max_title_len": 20,
        "max_snippet_len": 50,
        "conv_channels": [64, 128],
        "kernel_sizes": [3, 3],
    }
    classifier = CNNClassifier(**config)

    n_rows = 4
    # Input widths equal the max_title_len / max_snippet_len given above.
    title_ids = torch.randint(0, 1000, (n_rows, 20))
    snippet_ids = torch.randint(0, 1000, (n_rows, 50))
    result = classifier(title_ids, snippet_ids)

    assert tuple(result.shape) == (n_rows, 50)

    has_nan = torch.isnan(result).any()
    assert not has_nan
def test_cnn_classifier_shape_consistency():
    """Check CNNClassifier's output shape for inputs at the configured lengths.

    The model is built without explicit ``conv_channels`` / ``kernel_sizes``
    arguments, so whatever defaults the class defines are used. It is called
    with a batch of two samples and must return a ``(2, 50)`` tensor.
    """
    n_rows = 2
    title_len, snippet_len = 20, 50

    classifier = CNNClassifier(
        vocab_size=1000,
        embedding_dim=100,
        output_dim=50,
        max_title_len=title_len,
        max_snippet_len=snippet_len,
    )

    # Feed inputs at exactly the configured maximum lengths
    # (the model is designed for fixed sizes).
    title_ids = torch.randint(0, 1000, (n_rows, title_len))
    snippet_ids = torch.randint(0, 1000, (n_rows, snippet_len))
    result = classifier(title_ids, snippet_ids)

    assert tuple(result.shape) == (n_rows, 50)
|