''' Models definition ''' import math import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models class CRNN(nn.Module): def __init__(self, num_classes=100, hidden_size=256, resnet_pretrained_weights=None): super(CRNN, self).__init__() resnet = models.resnet18(weights=resnet_pretrained_weights) self.cnn = nn.Sequential(*list(resnet.children())[:-2]) self.feature_dim = 512 self.pool = nn.AdaptiveAvgPool2d((1, 1)) self.rnn = nn.LSTM(self.feature_dim, hidden_size, batch_first=True) self.dropout = nn.Dropout(0.3) self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x): B, T, C, H, W = x.size() x = x.view(B * T, C, H, W) features = self.cnn(x) pooled = self.pool(features).squeeze(-1).squeeze(-1) seq = pooled.view(B, T, self.feature_dim) rnn_out, _ = self.rnn(seq) final = self.dropout(rnn_out[:, -1, :]) return self.fc(final) class PositionalEncoding(nn.Module): '''Positional encoding for temporal sequence''' def __init__(self, d_model, max_len=64, dropout=0.1): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # (1, max_len, d_model) self.register_buffer("pe", pe) def forward(self, x): # x: (B, T, d_model) x = x + self.pe[:, :x.size(1), :] return self.dropout(x) class AttentionPooling(nn.Module): '''Attention pooling layer''' def __init__(self, dim): super().__init__() self.attention = nn.Sequential( nn.Linear(dim, dim // 4), nn.Tanh(), nn.Linear(dim // 4, 1) ) def forward(self, x): # x: (B, T, dim) attn_weights = self.attention(x) # (B, T, 1) attn_weights = F.softmax(attn_weights, dim=1) pooled = torch.sum(attn_weights * x, dim=1) # (B, dim) return pooled class ConvNeXtTransformer(nn.Module): ''' ConvNeXt-Tiny + Transformer Input: (B, T, C, H, W) = (B, 16, 3, 224, 224) Output: (B, num_classes) = (B, 100) ''' def __init__(self, num_classes=100, hidden_size=256, convnext_pretrained_weights=None): super().__init__() # 1. ConvNeXt-Tiny Backbone convnext = models.convnext_tiny(weights=convnext_pretrained_weights) self.cnn = convnext.features self.pool = nn.AdaptiveAvgPool2d((1, 1)) # ConvNeXt-Tiny output = 768 self.feature_dim = 768 # 2. Positional Encoding self.pos_encoder = PositionalEncoding( d_model=self.feature_dim, max_len=64, dropout=0.1 ) # 3. Transformer Encoder encoder_layer = nn.TransformerEncoderLayer( d_model=self.feature_dim, nhead=8, dim_feedforward=self.feature_dim * 4, dropout=0.3, activation="gelu", batch_first=True, norm_first=True ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2) # 4. Attention Pooling self.attention_pool = AttentionPooling(self.feature_dim) # 5. Classifier self.fc = nn.Sequential( nn.LayerNorm(self.feature_dim), nn.Dropout(0.4), nn.Linear(self.feature_dim, num_classes) ) self._init_weights() def _init_weights(self): for m in self.transformer.modules(): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) for m in self.attention_pool.modules(): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) def freeze_convnext_features(self, freeze_until=3): for i in range(freeze_until + 1): for p in self.cnn[i].parameters(): p.requires_grad = False def forward(self, x): B, T, C, H, W = x.shape # CNN: (B, T, C, H, W) → (B, T, 768) x = x.view(B * T, C, H, W) x = self.cnn(x) x = self.pool(x) x = x.view(B, T, self.feature_dim) # Transformer: (B, T, 768) → (B, T, 768) x = self.pos_encoder(x) x = self.transformer(x) # Pooling: (B, T, 768) → (B, 768) x = self.attention_pool(x) # Classifier: (B, 768) → (B, num_classes) x = self.fc(x) return x