File size: 1,269 Bytes
a2f07f7 e15d7de 79a87a4 a9165d8 79a87a4 558c6ba 79a87a4 558c6ba 79a87a4 a9165d8 79a87a4 558c6ba e15d7de 79a87a4 a9165d8 e15d7de 558c6ba 79a87a4 558c6ba e15d7de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import torch.nn as nn
class SignVLM(nn.Module):
    """Transformer encoder classifier for fixed-length sign-language pose clips.

    Pipeline: per-frame linear projection -> BatchNorm over the frame axis ->
    Transformer encoder -> mean-pool over time -> 2-layer MLP head.

    Args:
        input_dim: per-frame feature size (e.g. 225 pose keypoints).
        hidden_dim: transformer model width (d_model).
        num_heads: attention heads; must divide hidden_dim.
        num_layers: number of stacked encoder layers.
        num_classes: number of output classes.
        seq_len: fixed number of frames per clip. The BatchNorm1d layer is
            sized to this, so inputs must have exactly seq_len frames
            (default 64, matching the previously hard-coded value).
    """

    def __init__(self, input_dim=225, hidden_dim=512, num_heads=8,
                 num_layers=4, num_classes=60, seq_len=64):
        super().__init__()
        # Layer 0: projection to hidden_dim; Layer 1: BatchNorm across frames.
        # NOTE(review): BatchNorm1d(seq_len) treats frame positions as channels,
        # tying the model to a fixed clip length; LayerNorm(hidden_dim) would be
        # length-agnostic — kept as-is to preserve behavior.
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(seq_len),
        )
        # NOTE(review): no positional encoding is added before the encoder;
        # frame order only enters via the BatchNorm statistics — confirm intended.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=1024,
            batch_first=True,  # inputs are [batch, seq, feature]
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classification head with dropout regularization.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Classify a batch of pose sequences.

        Args:
            x: float tensor of shape [batch, seq_len, input_dim].

        Returns:
            Logits of shape [batch, num_classes].
        """
        # Call the Sequential directly instead of indexing its sublayers.
        x = self.feature_extractor(x)  # [batch, seq_len, hidden_dim]
        x = self.transformer(x)        # self-attention over frames
        x = x.mean(dim=1)              # global average pooling over time
        return self.classifier(x)
|