luciayen commited on
Commit
e15d7de
·
verified ·
1 Parent(s): a2f07f7

Fix: Upload full SignVLM architecture to model.py

Browse files
Files changed (1) hide show
  1. model.py +55 -1
model.py CHANGED
@@ -1,4 +1,58 @@
1
 
2
  import torch
3
  import torch.nn as nn
4
- # ... Paste your full model class code here ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  import torch
3
  import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, added element-wise to the input.

    Holds one trainable vector per position (up to ``sequence_length``) and
    adds the first ``seq_len`` of them to every sequence in the batch.
    """

    def __init__(self, sequence_length, embed_dim):
        super().__init__()
        # One learned embed_dim-vector per position index.
        self.pos_emb = nn.Embedding(sequence_length, embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim). Build position ids 0..seq_len-1 on
        # the same device, then broadcast the (1, seq_len, embed_dim) table
        # across the batch dimension.
        seq_len = x.size(1)
        position_ids = torch.arange(seq_len, device=x.device)
        return x + self.pos_emb(position_ids.unsqueeze(0))
14
+
15
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention then a two-layer
    feed-forward net, each wrapped with dropout, a residual connection, and
    layer normalization.

    Args:
        embed_dim: model width of the input/output (last dimension of x).
        num_heads: number of attention heads (must divide embed_dim).
        ff_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied inside attention and on both
            sub-layer outputs before the residual add.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # batch_first=True -> tensors are (batch, seq, embed_dim) throughout.
        self.att = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        # BUG FIX: the original used nn.LayerNormalization, which is the Keras
        # class name and does not exist in PyTorch — constructing the module
        # raised AttributeError. The PyTorch module is nn.LayerNorm.
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention sub-layer: residual + post-norm.
        attn_output, _ = self.att(x, x, x)
        x = self.layernorm1(x + self.dropout(attn_output))
        # Feed-forward sub-layer: residual + post-norm.
        ffn_output = self.ffn(x)
        x = self.layernorm2(x + self.dropout(ffn_output))
        return x
34
+
35
class SignVLM(nn.Module):
    """Transformer classifier over fixed-length keypoint sequences.

    Projects each frame's feature vector to ``embed_dim``, adds learned
    positional embeddings, runs 4 Transformer encoder blocks, mean-pools
    over time, and maps the pooled vector to class logits.

    Args:
        input_shape: (sequence_length, per-frame feature size); default (64, 225).
        num_classes: number of output classes.
        embed_dim: Transformer model width.
        num_heads: attention heads per block.
        ff_dim: feed-forward hidden width per block.
    """

    def __init__(self, input_shape=(64, 225), num_classes=60, embed_dim=256, num_heads=8, ff_dim=512):
        super().__init__()
        seq_len, feat_dim = input_shape
        self.dense_proj = nn.Linear(feat_dim, embed_dim)
        self.pos_emb = PositionalEmbedding(seq_len, embed_dim)

        # Four stacked encoder blocks (attribute name kept so state_dict
        # keys remain "transformer_blocks.<i>...").
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, ff_dim) for _ in range(4)]
        )

        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, feat_dim), e.g. (batch, 64, 225).
        hidden = self.pos_emb(self.dense_proj(x))

        for layer in self.transformer_blocks:
            hidden = layer(hidden)

        # Global average pooling over the time axis, then classify.
        pooled = hidden.mean(dim=1)
        return self.classifier(self.dropout(pooled))