import torch
import torch.nn as nn
class SignVLM(nn.Module):
    """Transformer classifier over sequences of per-frame feature vectors.

    Pipeline: per-frame linear projection -> BatchNorm1d over the frame axis ->
    TransformerEncoder -> mean pooling over frames -> MLP classifier.

    Args:
        input_dim: Size of each frame's feature vector (last input dim).
        hidden_dim: Transformer model width (``d_model``).
        num_heads: Attention heads per encoder layer; must divide ``hidden_dim``.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
        num_frames: Sequence length the model expects. ``nn.BatchNorm1d`` is
            applied with the frame axis as its channel axis, so its parameters
            are sized by this value and every input must have exactly this many
            frames.
        dim_feedforward: Inner size of each encoder layer's feed-forward block.
        classifier_dim: Hidden width of the classification MLP.
        classifier_dropout: Dropout probability inside the classification MLP.
    """

    def __init__(self, input_dim=225, hidden_dim=512, num_heads=8, num_layers=4,
                 num_classes=60, num_frames=64, dim_feedforward=1024,
                 classifier_dim=256, classifier_dropout=0.5):
        super().__init__()
        # Layer 0: per-frame projection input_dim -> hidden_dim.
        # Layer 1: on a (batch, frames, hidden) tensor BatchNorm1d uses dim 1 as
        # its channel axis, so it keeps one mean/var (and affine scale/shift) per
        # frame position; its size must therefore equal the sequence length.
        # NOTE(review): normalizing per feature (BatchNorm1d(hidden_dim) on the
        # transposed tensor, or LayerNorm) is the more common choice and would
        # not tie the model to a fixed length, but it changes checkpoint layout.
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(num_frames),
        )
        # NOTE(review): no explicit positional encoding is added. The only
        # position-dependent op is the per-frame BatchNorm above, so temporal
        # order is only weakly visible to the encoder.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, classifier_dim),
            nn.ReLU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(classifier_dim, num_classes),
        )

    def forward(self, x, src_key_padding_mask=None):
        """Return class logits for a batch of frame sequences.

        Args:
            x: Tensor of shape ``(batch, num_frames, input_dim)``.
            src_key_padding_mask: Optional ``(batch, num_frames)`` mask where
                ``True`` (or a non-zero additive value) marks padded frames,
                following the ``nn.TransformerEncoder`` convention. Padded
                frames are ignored by attention and excluded from pooling; in
                training mode they still contribute to BatchNorm statistics.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = self.feature_extractor(x)  # (batch, num_frames, hidden_dim)
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        if src_key_padding_mask is None:
            pooled = x.mean(dim=1)  # global average pooling over frames
        else:
            # Masked mean: average only the real (non-padded) frames. The clamp
            # avoids division by zero for a fully padded sequence.
            keep = (~src_key_padding_mask.bool()).unsqueeze(-1).to(x.dtype)
            pooled = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.classifier(pooled)