# Provenance: uploaded by lucvantien1211 (commit b20701a, verified).
# Commit message: "Upload src folder, which contains modules and scripts"
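'''
Model definitions for video classification: CNN backbones applied per frame,
followed by a temporal model over the sequence of frame features.
'''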
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class CRNN(nn.Module):
    '''
    ResNet-18 + LSTM

    Every frame is encoded by a ResNet-18 trunk, globally average-pooled to a
    single ``feature_dim`` vector, and the resulting frame sequence is run
    through an LSTM. The LSTM output at the last time step is passed through
    dropout and a linear classifier.

    Input: (B, T, C, H, W)
    Output: (B, num_classes)

    Args:
        num_classes: Number of output logits.
        hidden_size: Hidden size of the LSTM.
        resnet_pretrained_weights: Forwarded as ``weights`` to
            ``models.resnet18``.
    '''

    def __init__(self, num_classes=100, hidden_size=256, resnet_pretrained_weights=None):
        super().__init__()
        resnet = models.resnet18(weights=resnet_pretrained_weights)
        # Drop the last two child modules (avgpool and fc in torchvision's
        # ResNet) so the trunk returns a spatial feature map.
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        self.feature_dim = 512
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.rnn = nn.LSTM(self.feature_dim, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        '''
        Classify a batch of clips.

        Args:
            x: Tensor of shape (B, T, C, H, W). It does not need to be
                contiguous (e.g. it may be a permuted view).

        Returns:
            Tensor of shape (B, num_classes) with unnormalised logits.
        '''
        B, T, C, H, W = x.size()
        # reshape instead of view: view raises on non-contiguous input such as
        # a clip permuted from (B, C, T, H, W); reshape copies only if needed.
        frames = x.reshape(B * T, C, H, W)
        features = self.cnn(frames)
        # (B*T, feature_dim, 1, 1) -> (B*T, feature_dim)
        pooled = self.pool(features).flatten(1)
        seq = pooled.reshape(B, T, self.feature_dim)
        rnn_out, _ = self.rnn(seq)
        # Only the last time step summarises the whole clip.
        final = self.dropout(rnn_out[:, -1, :])
        return self.fc(final)
class PositionalEncoding(nn.Module):
    '''
    Positional encoding for temporal sequence

    Adds fixed sinusoidal position codes to the input and applies dropout.
    Even feature indices hold sines and odd indices hold cosines of
    ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature dimension of the sequence (even or odd).
        max_len: Number of positions precomputed; inputs to ``forward`` must
            not be longer than this.
        dropout: Dropout probability applied after the addition.
    '''

    def __init__(self, d_model, max_len=64, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so truncate the angles; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Buffer: follows the module across devices but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        '''
        Add position codes to ``x`` and apply dropout.

        Args:
            x: Tensor of shape (B, T, d_model) with T <= max_len.

        Returns:
            Tensor of shape (B, T, d_model).
        '''
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
class AttentionPooling(nn.Module):
    '''
    Attention pooling layer

    Scores every time step with a small two-layer network, normalises the
    scores over the time axis with a softmax, and returns the weighted sum of
    the time steps.

    Args:
        dim: Feature dimension of the input; the scorer's hidden width is
            ``dim // 4``.
    '''

    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 4
        scorer_layers = [
            nn.Linear(dim, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, x):
        '''
        Collapse the time axis of ``x``.

        Args:
            x: Tensor of shape (B, T, dim).

        Returns:
            Tensor of shape (B, dim).
        '''
        scores = self.attention(x)  # (B, T, 1)
        weights = torch.softmax(scores, dim=1)  # normalised across T
        return (weights * x).sum(dim=1)  # (B, dim)
class ConvNeXtTransformer(nn.Module):
    '''
    ConvNeXt-Tiny + Transformer

    Every frame is encoded by the ConvNeXt-Tiny feature extractor and globally
    average-pooled; the frame sequence then goes through positional encoding,
    a 2-layer pre-norm Transformer encoder, attention pooling over time and a
    LayerNorm/Dropout/Linear classifier.

    Input: (B, T, C, H, W) = (B, 16, 3, 224, 224)
    Output: (B, num_classes) = (B, 100)

    Args:
        num_classes: Number of output logits.
        hidden_size: Not used by this model; kept so the constructor signature
            parallels ``CRNN``.
        convnext_pretrained_weights: Forwarded as ``weights`` to
            ``models.convnext_tiny``.
    '''

    def __init__(self, num_classes=100, hidden_size=256, convnext_pretrained_weights=None):
        super().__init__()
        # 1. ConvNeXt-Tiny Backbone
        convnext = models.convnext_tiny(weights=convnext_pretrained_weights)
        self.cnn = convnext.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # ConvNeXt-Tiny output = 768
        self.feature_dim = 768
        # 2. Positional Encoding (supports clips of up to 64 frames)
        self.pos_encoder = PositionalEncoding(
            d_model=self.feature_dim,
            max_len=64,
            dropout=0.1
        )
        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.feature_dim,
            nhead=8,
            dim_feedforward=self.feature_dim * 4,
            dropout=0.3,
            activation="gelu",
            batch_first=True,
            norm_first=True
        )
        # The nested-tensor path is not used with norm_first=True; turning it
        # off explicitly avoids the warning PyTorch emits at construction.
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=2,
            enable_nested_tensor=False
        )
        # 4. Attention Pooling
        self.attention_pool = AttentionPooling(self.feature_dim)
        # 5. Classifier
        self.fc = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.4),
            nn.Linear(self.feature_dim, num_classes)
        )
        self._init_weights()

    def _init_weights(self):
        '''
        Re-initialise the ``nn.Linear`` layers of the transformer and of the
        attention pooling with truncated-normal weights (std 0.02) and zero
        biases. The backbone and the classifier are left untouched.
        '''
        # Transformer first, then attention pooling: same visiting order as
        # two separate loops, so random-number consumption is unchanged.
        for block in (self.transformer, self.attention_pool):
            for m in block.modules():
                if isinstance(m, nn.Linear):
                    nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def freeze_convnext_features(self, freeze_until=3):
        '''
        Disable gradients for backbone stages ``0..freeze_until`` (inclusive).

        Only ``requires_grad`` is changed. A ``freeze_until`` beyond the last
        stage freezes the whole backbone; a negative value freezes nothing.

        Args:
            freeze_until: Index of the last stage of ``self.cnn`` to freeze.
        '''
        stages = list(self.cnn.children())
        # Clamp so out-of-range values do not raise IndexError and negative
        # values do not wrap around through slice semantics.
        stop = max(0, min(freeze_until + 1, len(stages)))
        for stage in stages[:stop]:
            for p in stage.parameters():
                p.requires_grad = False

    def forward(self, x):
        '''
        Classify a batch of clips.

        Args:
            x: Tensor of shape (B, T, C, H, W) with T <= 64. It does not need
                to be contiguous (e.g. it may be a permuted view).

        Returns:
            Tensor of shape (B, num_classes) with unnormalised logits.
        '''
        B, T, C, H, W = x.shape
        # CNN: (B, T, C, H, W) → (B, T, 768)
        # reshape instead of view: view raises on non-contiguous input such as
        # a clip permuted from (B, C, T, H, W); reshape copies only if needed.
        x = x.reshape(B * T, C, H, W)
        x = self.cnn(x)
        x = self.pool(x)
        x = x.reshape(B, T, self.feature_dim)
        # Transformer: (B, T, 768) → (B, T, 768)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        # Pooling: (B, T, 768) → (B, 768)
        x = self.attention_pool(x)
        # Classifier: (B, 768) → (B, num_classes)
        x = self.fc(x)
        return x