File size: 2,253 Bytes
198ccb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Tests for model architectures."""

import pytest
import torch
from models.simple_classifier import SimpleClassifier
from models.cnn_classifier import CNNClassifier


def test_simple_classifier_title_only():
    """Check SimpleClassifier's forward pass when only a title is supplied.

    Builds the model with ``use_snippet=False``, feeds it a batch of random
    token ids, and verifies that the output has one row per batch element and
    ``output_dim`` columns, with no NaN entries.
    """
    vocab_size, output_dim = 1000, 50
    n_rows, n_tokens = 4, 20

    classifier = SimpleClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        use_snippet=False,
    )

    # Token ids are drawn from [0, vocab_size) so every id is a valid index.
    title_ids = torch.randint(0, vocab_size, (n_rows, n_tokens))

    result = classifier(title_ids)

    expected_shape = (n_rows, output_dim)
    assert result.shape == expected_shape
    nan_mask = torch.isnan(result)
    assert not nan_mask.any()


def test_simple_classifier_with_snippet():
    """Check SimpleClassifier's forward pass with both title and snippet.

    Builds the model with ``use_snippet=True``, feeds it random title and
    snippet token ids of different lengths, and verifies that the output has
    one row per batch element and ``output_dim`` columns, with no NaN entries.
    """
    vocab_size, output_dim = 1000, 50
    n_rows = 4
    n_title_tokens, n_snippet_tokens = 20, 50

    classifier = SimpleClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        use_snippet=True,
    )

    # Token ids are drawn from [0, vocab_size) so every id is a valid index.
    title_ids = torch.randint(0, vocab_size, (n_rows, n_title_tokens))
    snippet_ids = torch.randint(0, vocab_size, (n_rows, n_snippet_tokens))

    result = classifier(title_ids, snippet_ids)

    expected_shape = (n_rows, output_dim)
    assert result.shape == expected_shape
    nan_mask = torch.isnan(result)
    assert not nan_mask.any()


def test_cnn_classifier():
    """Check CNNClassifier's forward pass with an explicit conv configuration.

    Builds the model with two conv stages (64 and 128 channels, kernel size 3
    each), feeds it random title and snippet token ids whose lengths equal the
    configured maxima, and verifies that the output has one row per batch
    element and ``output_dim`` columns, with no NaN entries.
    """
    vocab_size, output_dim = 1000, 50
    title_len, snippet_len = 20, 50
    n_rows = 4

    classifier = CNNClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        max_title_len=title_len,
        max_snippet_len=snippet_len,
        conv_channels=[64, 128],
        kernel_sizes=[3, 3],
    )

    # Inputs are sized to exactly the maximum lengths the model was built for.
    title_ids = torch.randint(0, vocab_size, (n_rows, title_len))
    snippet_ids = torch.randint(0, vocab_size, (n_rows, snippet_len))

    result = classifier(title_ids, snippet_ids)

    expected_shape = (n_rows, output_dim)
    assert result.shape == expected_shape
    nan_mask = torch.isnan(result)
    assert not nan_mask.any()


def test_cnn_classifier_shape_consistency():
    """Check CNNClassifier's output with its default conv configuration.

    Builds the model without passing ``conv_channels`` / ``kernel_sizes``,
    feeds it random title and snippet token ids whose lengths equal the
    configured maxima, and verifies that the output has one row per batch
    element and ``output_dim`` columns, with no NaN entries.
    """
    vocab_size, output_dim = 1000, 50
    title_len, snippet_len = 20, 50
    batch_size = 2

    model = CNNClassifier(
        vocab_size=vocab_size,
        embedding_dim=100,
        output_dim=output_dim,
        max_title_len=title_len,
        max_snippet_len=snippet_len,
    )

    # Test with expected max sequence lengths (model is designed for fixed sizes)
    title = torch.randint(0, vocab_size, (batch_size, title_len))
    snippet = torch.randint(0, vocab_size, (batch_size, snippet_len))

    output = model(title, snippet)

    assert output.shape == (batch_size, output_dim)
    # Same NaN guard as the other model tests: a correctly shaped output can
    # still consist entirely of NaNs.
    assert not torch.isnan(output).any()