|
|
"""
|
|
|
RNN Model Architecture for CIFAR-10 Classification
|
|
|
"""
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import config
|
|
|
|
|
|
|
|
|
class CIFAR10RNN(nn.Module):
    """
    Recurrent Neural Network (LSTM) for CIFAR-10 classification.

    The 32x32 RGB image is read as a sequence of 32 rows, each row
    flattened to 32 * 3 = 96 features per time step.

    Architecture:
        - Input sequence: 32 rows of 32x3 pixels (= 96 features per step)
        - Bidirectional LSTM layers
        - Fully connected head for classification
    """

    def __init__(self, input_size=96, hidden_size=256, num_layers=2, num_classes=10):
        """
        Args:
            input_size (int): Features per time step (row width * channels).
            hidden_size (int): LSTM hidden state size per direction.
            num_layers (int): Number of stacked LSTM layers.
            num_classes (int): Number of output classes.
        """
        super(CIFAR10RNN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout only applies when num_layers > 1;
            # nn.LSTM warns otherwise, so disable it explicitly.
            dropout=config.RNN_DROPOUT if num_layers > 1 else 0
        )

        # Bidirectional LSTM concatenates forward and backward states,
        # hence 2 * hidden_size input features for the classifier head.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Input batch of shape (batch, channels, height, width).

        Returns:
            Tensor: Class logits of shape (batch, num_classes).
        """
        batch_size = x.size(0)

        # (B, C, H, W) -> (B, H, C, W): each image row becomes one time step.
        x = x.permute(0, 2, 1, 3).contiguous()
        # Flatten channels and width into the per-step feature vector.
        # Use the actual height rather than a hard-coded 32 so inputs of
        # other spatial sizes also work (backward-compatible for CIFAR).
        x = x.view(batch_size, x.size(1), -1)

        out, _ = self.lstm(x)

        # BUG FIX: out[:, -1, :] mixes the forward state at the LAST step
        # (which has seen the full sequence) with the backward state at the
        # last step (which has seen only ONE row). Take the forward
        # direction's last step and the backward direction's first step so
        # both halves summarize the entire sequence.
        fwd_last = out[:, -1, :self.hidden_size]
        bwd_last = out[:, 0, self.hidden_size:]
        out = torch.cat([fwd_last, bwd_last], dim=1)

        out = self.fc(out)

        return out
|
|
|
|
|
|
|
|
|
def get_model(num_classes=10, device='cpu'):
    """
    Build the CIFAR-10 RNN classifier and move it to the target device.

    Args:
        num_classes (int): Number of output classes.
        device (str or torch.device): Device to place the model on.

    Returns:
        CIFAR10RNN: The model, already transferred to ``device``.
    """
    # Each time step is one image row: 32 pixels * 3 channels = 96 features.
    return CIFAR10RNN(
        input_size=32 * 3,
        hidden_size=config.HIDDEN_SIZE,
        num_layers=config.NUM_LAYERS,
        num_classes=num_classes
    ).to(device)
|
|
|
|
|
|
|
|
|
def count_parameters(model):
    """
    Count the number of trainable parameters in the model

    Only parameters with ``requires_grad=True`` are included.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|