| | """Tests for model architectures.""" |
| |
|
| | import pytest |
| | import torch |
| | from models.simple_classifier import SimpleClassifier |
| | from models.cnn_classifier import CNNClassifier |
| |
|
| |
|
def test_simple_classifier_title_only():
    """Run a title-only forward pass through SimpleClassifier.

    The model is built with ``use_snippet=False`` and called with a single
    tensor of random integers drawn from ``[0, 1000)``. The test passes
    when the result has shape ``(batch, output_dim)`` and holds no NaNs.
    """
    vocab_size = 1000
    output_dim = 50
    n_samples, n_tokens = 4, 20

    # Build the model before drawing inputs so RNG consumption order is
    # model initialisation first, input sampling second.
    classifier = SimpleClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        use_snippet=False,
    )
    title_batch = torch.randint(0, vocab_size, (n_samples, n_tokens))

    result = classifier(title_batch)

    expected_shape = (n_samples, output_dim)
    assert result.shape == expected_shape
    contains_nan = torch.isnan(result).any()
    assert not contains_nan
| |
|
| |
|
def test_simple_classifier_with_snippet():
    """Run a title-plus-snippet forward pass through SimpleClassifier.

    The model is built with ``use_snippet=True`` and called with two
    tensors of random integers drawn from ``[0, 1000)``: a title batch of
    width 20 and a snippet batch of width 50. The test passes when the
    result has shape ``(batch, output_dim)`` and holds no NaNs.
    """
    vocab_size = 1000
    output_dim = 50
    n_samples = 4
    title_width, snippet_width = 20, 50

    classifier = SimpleClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        use_snippet=True,
    )

    # Title is sampled before snippet, matching the argument order below.
    title_batch = torch.randint(0, vocab_size, (n_samples, title_width))
    snippet_batch = torch.randint(0, vocab_size, (n_samples, snippet_width))

    result = classifier(title_batch, snippet_batch)

    expected_shape = (n_samples, output_dim)
    assert result.shape == expected_shape
    contains_nan = torch.isnan(result).any()
    assert not contains_nan
| |
|
| |
|
def test_cnn_classifier():
    """Run a forward pass through CNNClassifier with explicit conv settings.

    The model is configured with two convolution channel sizes (64, 128)
    and two kernel sizes (3, 3), then called with title and snippet
    tensors whose widths equal ``max_title_len`` and ``max_snippet_len``.
    The test passes when the result has shape ``(batch, output_dim)`` and
    holds no NaNs.
    """
    n_samples = 4
    model_kwargs = {
        "vocab_size": 1000,
        "embedding_dim": 100,
        "output_dim": 50,
        "max_title_len": 20,
        "max_snippet_len": 50,
        "conv_channels": [64, 128],
        "kernel_sizes": [3, 3],
    }
    classifier = CNNClassifier(**model_kwargs)

    # Input widths are taken from the same values the model was given.
    title_batch = torch.randint(
        0, model_kwargs["vocab_size"], (n_samples, model_kwargs["max_title_len"])
    )
    snippet_batch = torch.randint(
        0, model_kwargs["vocab_size"], (n_samples, model_kwargs["max_snippet_len"])
    )

    result = classifier(title_batch, snippet_batch)

    assert result.shape == (n_samples, model_kwargs["output_dim"])
    contains_nan = torch.isnan(result).any()
    assert not contains_nan
| |
|
| |
|
def test_cnn_classifier_shape_consistency():
    """Check CNNClassifier's output shape when conv settings are omitted.

    The model is built without ``conv_channels`` / ``kernel_sizes``, so
    whatever defaults the class defines are used. It is then called with a
    batch of two title rows (width 20) and two snippet rows (width 50);
    only the output shape ``(2, 50)`` is asserted.
    """
    n_samples = 2
    vocab_size, output_dim = 1000, 50
    title_width, snippet_width = 20, 50

    classifier = CNNClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        max_title_len=title_width,
        max_snippet_len=snippet_width,
    )

    title_batch = torch.randint(0, vocab_size, (n_samples, title_width))
    snippet_batch = torch.randint(0, vocab_size, (n_samples, snippet_width))

    result = classifier(title_batch, snippet_batch)
    assert result.shape == (n_samples, output_dim)
| |
|
| |
|