File size: 2,664 Bytes
233caeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""

RNN Model Architecture for CIFAR-10 Classification

"""
import torch
import torch.nn as nn
import config


class CIFAR10RNN(nn.Module):
    """Bidirectional LSTM classifier for CIFAR-10.

    Each image is treated as a sequence of pixel rows: a (3, 32, 32) image
    becomes 32 time steps of 96 features (3 channels x 32 columns).

    Architecture:
    - Input sequence: 32 rows of 32x3 pixels (= 96 features per step)
    - Bidirectional LSTM layers
    - Fully connected head for classification

    Args:
        input_size (int): Features per time step (channels * width).
        hidden_size (int): LSTM hidden state size per direction.
        num_layers (int): Number of stacked LSTM layers.
        num_classes (int): Number of output classes.
    """

    def __init__(self, input_size=96, hidden_size=256, num_layers=2, num_classes=10):
        super(CIFAR10RNN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM Layer
        # batch_first=True means input shape is (batch, seq, feature)
        # Inter-layer dropout is only valid when num_layers > 1 (PyTorch
        # warns otherwise), hence the conditional.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=config.RNN_DROPOUT if num_layers > 1 else 0
        )

        # Fully Connected Layer
        # * 2 because of bidirectional
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Run a batch of images through the network.

        Args:
            x (torch.Tensor): Batch of images, shape (batch, C, H, W) —
                normally (batch, 3, 32, 32) for CIFAR-10.

        Returns:
            torch.Tensor: Class logits, shape (batch, num_classes).
        """
        batch_size = x.size(0)

        # Rearrange image rows into a sequence:
        # (batch, C, H, W) -> (batch, H, C, W) -> (batch, H, C*W).
        # H is taken from the tensor rather than hard-coded so any input
        # height whose C*W equals input_size works.
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, x.size(1), -1)

        # LSTM forward pass.
        # out: (batch, seq_len, hidden_size * 2) — per step, the first
        # hidden_size features are the forward direction, the rest backward.
        out, _ = self.lstm(x)

        # Final summary of the sequence. For a bidirectional LSTM the
        # forward direction has seen the whole sequence at the LAST step,
        # while the backward direction has seen it all at the FIRST step;
        # taking out[:, -1, :] would use a backward state that has only
        # processed one row. Concatenate the two fully-informed states.
        out = torch.cat(
            (out[:, -1, :self.hidden_size], out[:, 0, self.hidden_size:]),
            dim=1,
        )

        # Classification head.
        out = self.fc(out)

        return out


def get_model(num_classes=10, device='cpu'):
    """Build a CIFAR10RNN configured from the project config.

    Hidden size and layer count come from ``config``; the per-step input
    size is fixed to one 32-pixel RGB row (32 * 3 = 96 features).

    Args:
        num_classes (int): Number of output classes.
        device (str or torch.device): Device to load the model on.

    Returns:
        CIFAR10RNN: The RNN model, already moved to ``device``.
    """
    rnn = CIFAR10RNN(
        input_size=32 * 3,
        hidden_size=config.HIDDEN_SIZE,
        num_layers=config.NUM_LAYERS,
        num_classes=num_classes,
    )
    return rnn.to(device)


def count_parameters(model):
    """Return the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    layers do not contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total