File size: 1,269 Bytes
a2f07f7
 
 
e15d7de
79a87a4
a9165d8
79a87a4
 
558c6ba
79a87a4
 
558c6ba
 
79a87a4
 
 
 
 
a9165d8
79a87a4
 
 
 
 
558c6ba
e15d7de
79a87a4
a9165d8
e15d7de
 
 
558c6ba
 
 
 
79a87a4
558c6ba
e15d7de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

import torch
import torch.nn as nn

class SignVLM(nn.Module):
    """Transformer encoder classifier for fixed-length sign-language keypoint clips.

    Input is a tensor of shape ``[batch, seq_len, input_dim]`` (keypoint
    features per frame). Frames are linearly projected to ``hidden_dim``,
    normalized, passed through a Transformer encoder, mean-pooled over time,
    and classified.

    Args:
        input_dim:  Per-frame feature dimension (e.g. 225 flattened keypoints).
        hidden_dim: Transformer model width (``d_model``).
        num_heads:  Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Output logit count.
        seq_len:    Number of frames per clip. Previously hard-coded to 64
                    inside ``BatchNorm1d``; parameterized so other clip
                    lengths work (default preserves original behavior).
    """

    def __init__(self, input_dim=225, hidden_dim=512, num_heads=8, num_layers=4, num_classes=60, seq_len=64):
        super().__init__()

        # Layer 0: Projection to hidden_dim.
        # NOTE(review): BatchNorm1d(seq_len) on a [batch, seq_len, hidden_dim]
        # tensor treats each *frame position* as a channel, i.e. statistics are
        # computed per temporal index across batch and feature dims. That is
        # unusual (LayerNorm(hidden_dim) is the conventional choice) — kept
        # as-is to preserve trained-weight compatibility; confirm intent.
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # Layer 1: BatchNorm across the seq_len frames
            nn.BatchNorm1d(seq_len)
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=1024,
            batch_first=True  # inputs/outputs are [batch, seq, feature]
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # MLP head: hidden_dim -> 256 -> num_classes, with dropout for regularization.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Float tensor of shape ``[batch, seq_len, input_dim]``.

        Returns:
            Logits of shape ``[batch, num_classes]``.
        """
        # Sequential applies Linear then BatchNorm1d in order; no need to
        # index the submodules individually.
        x = self.feature_extractor(x)        # [batch, seq_len, hidden_dim]
        x = self.transformer(x)              # [batch, seq_len, hidden_dim]
        x = x.mean(dim=1)                    # Global average pooling over time
        return self.classifier(x)