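'''
Model definitions for classifying sequences of image frames shaped
(B, T, C, H, W): a ResNet-18 + LSTM model (CRNN) and a
ConvNeXt-Tiny + Transformer model (ConvNeXtTransformer).
'''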
|
|
|
| import math
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torchvision import models
|
|
|
|
|
class CRNN(nn.Module):
    """ResNet-18 frame encoder followed by an LSTM over the time axis.

    Every frame is encoded independently by the ResNet-18 trunk (its own
    avgpool and fc removed). The result is globally average-pooled to one
    vector per frame, and the per-frame vectors are fed, in order, to an
    LSTM. The LSTM output at the last time step goes through dropout and a
    linear classifier.

    Args:
        num_classes: Number of output logits.
        hidden_size: Hidden size of the LSTM.
        resnet_pretrained_weights: Forwarded to
            ``torchvision.models.resnet18`` as ``weights``.
    """

    def __init__(self, num_classes=100, hidden_size=256, resnet_pretrained_weights=None):
        super().__init__()

        resnet = models.resnet18(weights=resnet_pretrained_weights)
        # Keep everything except ResNet's avgpool and fc; pooling is
        # re-applied below so the trunk yields one vector per frame.
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        # Channel width of the trunk's output (512 for resnet18).
        self.feature_dim = resnet.fc.in_features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.rnn = nn.LSTM(self.feature_dim, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape (B, T, C, H, W).

        Returns:
            Logits of shape (B, num_classes).
        """
        batch, steps, channels, height, width = x.shape
        # Fold time into the batch so the CNN sees independent frames.
        # reshape (not view) also accepts non-contiguous inputs, e.g. a
        # tensor produced by permute().
        frames = x.reshape(batch * steps, channels, height, width)
        # (B*T, feature_dim, 1, 1) -> (B*T, feature_dim)
        pooled = self.pool(self.cnn(frames)).flatten(1)
        seq = pooled.reshape(batch, steps, self.feature_dim)

        rnn_out, _ = self.rnn(seq)
        # Only the last time step's output is classified.
        final = self.dropout(rnn_out[:, -1, :])
        return self.fc(final)
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a temporal sequence.

    Even feature columns hold ``sin(pos * f_i)`` and odd columns hold
    ``cos(pos * f_i)``, with ``f_i = 10000 ** (-2i / d_model)``. The table is
    precomputed for ``max_len`` positions and stored as a buffer.

    Args:
        d_model: Feature size of each time step (odd sizes are supported).
        max_len: Longest sequence length the precomputed table covers.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=64, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) values.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, one fewer than the
        # number of frequencies when d_model is odd; slice to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading singleton dim so the table broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions, then dropout.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        return self.dropout(x + self.pe[:, :seq_len, :])
|
|
|
|
|
class AttentionPooling(nn.Module):
    """Collapse dim 1 of the input with learned, softmax-normalised weights.

    A small MLP (Linear -> Tanh -> Linear) maps each step's feature vector
    to one scalar score. The scores are softmax-normalised along dim 1 and
    used as weights for a weighted sum of the inputs along that same dim.

    Args:
        dim: Size of the last (feature) dimension of the input.
    """

    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 4
        # The (Linear, Tanh, Linear) ordering fixes the state_dict keys
        # "attention.0.*" and "attention.2.*"; keep it stable for checkpoints.
        self.attention = nn.Sequential(
            nn.Linear(dim, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, 1),
        )

    def forward(self, x):
        """Return the attention-weighted sum of ``x`` over dim 1."""
        scores = self.attention(x)       # trailing dim of size 1
        weights = scores.softmax(dim=1)  # normalised along dim 1
        weighted = weights * x
        return weighted.sum(dim=1)
|
|
|
|
|
class ConvNeXtTransformer(nn.Module):
    """ConvNeXt-Tiny frame encoder + Transformer encoder over time.

    Input:  (B, T, C, H, W), e.g. (B, 16, 3, 224, 224)
    Output: (B, num_classes), e.g. (B, 100)

    Each frame goes through the ConvNeXt-Tiny feature trunk and is globally
    average-pooled to a ``feature_dim`` vector. Sinusoidal positional
    encodings are added, two pre-norm Transformer encoder layers mix
    information across time, attention pooling collapses the time axis,
    and a LayerNorm -> Dropout -> Linear head produces the logits.

    Args:
        num_classes: Number of output logits.
        hidden_size: Unused by this model; still accepted so existing call
            sites that pass it keep working.
        convnext_pretrained_weights: Forwarded to
            ``torchvision.models.convnext_tiny`` as ``weights``.
        max_seq_len: Longest T covered by the positional-encoding table
            (default 64, the previously hard-coded value).
    """

    def __init__(self, num_classes=100, hidden_size=256, convnext_pretrained_weights=None,
                 max_seq_len=64):
        super().__init__()

        # --- frame encoder ---
        convnext = models.convnext_tiny(weights=convnext_pretrained_weights)
        self.cnn = convnext.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Channel width of the ConvNeXt-Tiny trunk output.
        self.feature_dim = 768

        # --- temporal model ---
        self.pos_encoder = PositionalEncoding(
            d_model=self.feature_dim,
            max_len=max_seq_len,
            dropout=0.1,
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.feature_dim,
            nhead=8,
            dim_feedforward=self.feature_dim * 4,
            dropout=0.3,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        # forward() never passes a padding mask, so the nested-tensor fast
        # path is not used; disabling it only silences PyTorch's warning
        # about norm_first=True.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=2, enable_nested_tensor=False
        )

        # --- pooling + classification head ---
        self.attention_pool = AttentionPooling(self.feature_dim)
        self.fc = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.4),
            nn.Linear(self.feature_dim, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the nn.Linear layers of the transformer and attention pool.

        Weights are drawn from trunc_normal(std=0.02) and biases are zeroed.
        The CNN trunk and the fc head keep the initialisation they were built
        with. NOTE(review): parameters not held in an nn.Linear module (e.g.
        a packed attention in-projection, if present) are left untouched.
        """
        # Same visiting order as before (transformer first), so RNG draws match.
        for part in (self.transformer, self.attention_pool):
            for m in part.modules():
                if isinstance(m, nn.Linear):
                    nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def freeze_convnext_features(self, freeze_until=3):
        """Set requires_grad=False on the first ``freeze_until + 1`` children of ``self.cnn``.

        Values past the last child freeze the whole trunk instead of raising.
        Negative values freeze nothing.
        """
        stop = max(freeze_until + 1, 0)
        for p in self.cnn[:stop].parameters():
            p.requires_grad = False

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape (B, T, C, H, W).

        Returns:
            Logits of shape (B, num_classes).
        """
        batch, steps, channels, height, width = x.shape

        # Fold time into the batch so the CNN sees independent frames.
        # reshape (not view) also accepts non-contiguous inputs.
        frames = x.reshape(batch * steps, channels, height, width)
        feats = self.pool(self.cnn(frames))          # (B*T, feature_dim, 1, 1)
        seq = feats.reshape(batch, steps, self.feature_dim)

        seq = self.pos_encoder(seq)
        seq = self.transformer(seq)

        pooled = self.attention_pool(seq)            # (B, feature_dim)
        return self.fc(pooled)