import torch
import torch.nn as nn


class SignVLM(nn.Module):
    """Transformer-based classifier over fixed-length pose/landmark sequences.

    Expects input of shape (batch, seq_len, input_dim) — e.g. 64 frames of
    225-dimensional keypoint features — and returns logits of shape
    (batch, num_classes).
    """

    def __init__(self, input_dim=225, hidden_dim=512, num_heads=8,
                 num_layers=4, num_classes=60, seq_len=64):
        """
        Args:
            input_dim: per-frame feature dimension.
            hidden_dim: transformer model dimension (d_model).
            num_heads: attention heads per encoder layer.
            num_layers: number of transformer encoder layers.
            num_classes: number of output classes.
            seq_len: number of frames per clip. Previously hard-coded to 64
                inside BatchNorm1d; parameterized (default 64 preserves the
                original behavior) so other clip lengths work.
        """
        super().__init__()
        # Project per-frame features to hidden_dim, then batch-normalize.
        # NOTE(review): BatchNorm1d(seq_len) applied to (batch, seq_len,
        # hidden_dim) treats the *frame* axis as the channel axis, i.e. it
        # normalizes across frames rather than across features, which is why
        # the model is tied to a fixed seq_len. Kept as-is to preserve
        # behavior — confirm LayerNorm(hidden_dim) wasn't the actual intent.
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(seq_len),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=1024,
            batch_first=True,  # inputs/outputs are (batch, seq, feature)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map (batch, seq_len, input_dim) frames to (batch, num_classes) logits."""
        # Calling the Sequential directly is equivalent to the original
        # element-wise indexing: Linear -> (B, S, H), then BatchNorm1d(S).
        x = self.feature_extractor(x)
        x = self.transformer(x)        # (batch, seq_len, hidden_dim)
        x = x.mean(dim=1)              # global average pool over frames
        return self.classifier(x)