# CASL-TransSLR / model.py
# Author: luciayen
# Final architecture alignment: BatchNorm(64) for temporal normalization
# (commit 558c6ba, verified)
import torch
import torch.nn as nn
class SignVLM(nn.Module):
    """Transformer-based sign language recognition model.

    Pipeline: per-frame keypoint features are linearly projected to a hidden
    embedding, normalized across the temporal axis with BatchNorm1d, passed
    through a Transformer encoder, mean-pooled over time, and classified.

    Args:
        input_dim: Per-frame feature size (default 225, e.g. flattened keypoints).
        hidden_dim: Embedding size; must be divisible by ``num_heads``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output classes.
        seq_len: Fixed number of frames per clip. Used as the channel count of
            the temporal BatchNorm, so inputs must have exactly this many
            frames. Defaults to 64, matching the original architecture.
    """

    def __init__(self, input_dim=225, hidden_dim=512, num_heads=8,
                 num_layers=4, num_classes=60, seq_len=64):
        super(SignVLM, self).__init__()
        # Layer 0: projection to hidden_dim; Layer 1: BatchNorm across the
        # seq_len frames. BatchNorm1d(seq_len) treats the frame axis (dim 1)
        # of a [batch, seq_len, hidden_dim] tensor as channels, which ties
        # the model to clips of exactly seq_len frames.
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(seq_len),
        )
        # NOTE(review): no positional encoding is added before the encoder,
        # so frame order reaches the attention layers only indirectly (via
        # the per-frame BatchNorm statistics). Confirm this is intentional.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=1024,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classification head: bottleneck to 256 with dropout regularization.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Classify a clip.

        Args:
            x: Tensor of shape [batch, seq_len, input_dim].

        Returns:
            Logits of shape [batch, num_classes].
        """
        x = self.feature_extractor(x)  # [batch, seq_len, hidden_dim]
        x = self.transformer(x)
        x = x.mean(dim=1)              # global average pooling over time
        return self.classifier(x)