2264K committed on
Commit
3f1b0bf
·
verified ·
1 Parent(s): 7986cb2

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +209 -0
model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 2-A Toy PoC: 3-way Modality-Specific FFN (Vision + Audio + Text)
3
+ Shared Attention + ffn_vision / ffn_audio / ffn_text
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
# Default hyperparameters for the toy tri-modal PoC model.
CONFIG = {
    "d_model": 256,         # shared transformer hidden size
    "n_heads": 4,           # attention heads per block
    "ffn_dim": 512,         # hidden size of each modality-specific FFN
    "n_layers": 6,          # number of TriModalTransformerBlock layers
    "vocab_size": 10000,    # text vocabulary size (Embedding rows / text head outputs)
    "patch_size": 16,       # vision patch edge; flattened patch_dim = patch_size ** 2
    "max_seq_len": 512,     # positional-embedding capacity for the concatenated sequence
    "dropout": 0.1,         # dropout used in attention and FFNs
    "audio_feat_dim": 768,  # incoming audio feature dim — presumably a pretrained encoder; TODO confirm
}
21
+
22
+
23
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Linear, dropout after each linear.

    Applied independently at every sequence position; the last dimension
    goes d_model -> ffn_dim -> d_model.
    """

    def __init__(self, d_model, ffn_dim, dropout=0.1):
        super().__init__()
        # Built as a list first, then wrapped — keeps the same `net.*`
        # submodule names (and parameter-init order) as a direct Sequential.
        layers = [
            nn.Linear(d_model, ffn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return the MLP output; shape matches the input except nothing changes in the last dim."""
        return self.net(inputs)
36
+
37
+
38
class TriModalTransformerBlock(nn.Module):
    """Pre-norm transformer block with shared attention and 3-way routed FFNs.

    All modalities attend jointly through one MultiheadAttention; the FFN
    stage then routes vision / audio / text positions (selected by the index
    tensors) through their own FeedForward.

    NOTE(review): the FFN outputs are re-assembled with torch.cat, so this
    assumes v_idx, a_idx, t_idx are contiguous ranges ordered
    [vision | audio | text] along the sequence — confirm against the caller.
    """

    def __init__(self, d_model, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn_vision = FeedForward(d_model, ffn_dim, dropout)
        self.ffn_audio = FeedForward(d_model, ffn_dim, dropout)
        self.ffn_text = FeedForward(d_model, ffn_dim, dropout)

    def forward(self, x, attn_mask, v_idx, a_idx, t_idx):
        """Return (updated hidden states, per-head attention weights)."""
        # --- shared self-attention, pre-norm with residual ---
        normed = self.norm1(x)
        attn_out, attn_weights = self.attn(
            normed, normed, normed,
            attn_mask=attn_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        x = x + attn_out

        # --- modality-specific FFNs, pre-norm with residual ---
        normed = self.norm2(x)
        per_modality = [
            self.ffn_vision(normed[:, v_idx, :]),
            self.ffn_audio(normed[:, a_idx, :]),
            self.ffn_text(normed[:, t_idx, :]),
        ]
        x = x + torch.cat(per_modality, dim=1)

        return x, attn_weights
70
+
71
+
72
class TriModalModel(nn.Module):
    """Toy tri-modal transformer (Phase 2-A PoC).

    Vision patches, audio features, and text tokens are each projected into
    a shared d_model space, concatenated as [vision | audio | text], given
    learned absolute positions, and run through TriModalTransformerBlock
    layers (shared attention + modality-specific FFNs). Three linear heads
    map the final states back to per-modality outputs.

    Fix over the original: forward() validates the combined sequence length
    against cfg["max_seq_len"] and raises a clear ValueError instead of
    letting the positional nn.Embedding fail with an opaque index error
    (or a CUDA device-side assert).
    """

    def __init__(self, cfg=None):
        """cfg: dict with the keys of CONFIG; defaults to the module-level CONFIG."""
        super().__init__()
        cfg = cfg or CONFIG
        self.cfg = cfg
        d = cfg["d_model"]
        patch_dim = cfg["patch_size"] ** 2  # flattened square patch length

        # Per-modality input projections, each followed by its own LayerNorm.
        self.vision_embed = nn.Linear(patch_dim, d)
        self.audio_proj = nn.Linear(cfg["audio_feat_dim"], d)
        self.text_embed = nn.Embedding(cfg["vocab_size"], d)
        self.vision_norm = nn.LayerNorm(d)
        self.audio_norm = nn.LayerNorm(d)
        self.text_norm = nn.LayerNorm(d)
        # Learned absolute positions over the concatenated sequence.
        self.pos_embed = nn.Embedding(cfg["max_seq_len"], d)

        # Transformer trunk.
        self.blocks = nn.ModuleList([
            TriModalTransformerBlock(d, cfg["n_heads"], cfg["ffn_dim"], cfg["dropout"])
            for _ in range(cfg["n_layers"])
        ])
        self.final_norm = nn.LayerNorm(d)

        # Per-modality output heads.
        self.vision_head = nn.Linear(d, patch_dim)
        self.audio_head = nn.Linear(d, cfg["audio_feat_dim"])
        self.text_head = nn.Linear(d, cfg["vocab_size"])

        self._init_weights()

    def _init_weights(self):
        """GPT-style init: N(0, 0.02) weights, zero biases, unit LayerNorm scale."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, vision_patches, audio_features, text_tokens, return_attn=False):
        """Run one forward pass over all three modalities.

        Args:
            vision_patches: (B, N_v, patch_size**2) float tensor.
            audio_features: (B, N_a, audio_feat_dim) float tensor.
            text_tokens: (B, N_t) long tensor of token ids.
            return_attn: if True, also return per-layer attention weights.

        Returns:
            (vision_out, audio_out, text_out), plus a list of detached
            per-layer attention tensors when return_attn is True.

        Raises:
            ValueError: if N_v + N_a + N_t exceeds cfg["max_seq_len"].
        """
        N_v = vision_patches.size(1)
        N_a = audio_features.size(1)
        N_t = text_tokens.size(1)
        N = N_v + N_a + N_t
        device = text_tokens.device

        # Fail fast with a readable message instead of an embedding index error.
        if N > self.cfg["max_seq_len"]:
            raise ValueError(
                f"Combined sequence length {N} (vision {N_v} + audio {N_a} + "
                f"text {N_t}) exceeds max_seq_len {self.cfg['max_seq_len']}"
            )

        # Embed each modality into the shared d_model space.
        v_emb = self.vision_norm(self.vision_embed(vision_patches))
        a_emb = self.audio_norm(self.audio_proj(audio_features))
        t_emb = self.text_norm(self.text_embed(text_tokens))

        # Concatenate as [vision | audio | text] and add positions.
        x = torch.cat([v_emb, a_emb, t_emb], dim=1)
        pos = torch.arange(N, device=device)
        x = x + self.pos_embed(pos)

        # Cross-modal mask + contiguous per-modality index ranges
        # (the blocks' torch.cat reassembly relies on this ordering).
        attn_mask = self._build_attn_mask(N_v, N_a, N_t, device)
        v_idx = torch.arange(0, N_v, device=device)
        a_idx = torch.arange(N_v, N_v + N_a, device=device)
        t_idx = torch.arange(N_v + N_a, N, device=device)

        # Transformer trunk.
        all_attn = []
        for block in self.blocks:
            x, attn_w = block(x, attn_mask, v_idx, a_idx, t_idx)
            if return_attn:
                all_attn.append(attn_w.detach())

        x = self.final_norm(x)

        # Slice the sequence back apart and apply the modality heads.
        vision_out = self.vision_head(x[:, :N_v, :])
        audio_out = self.audio_head(x[:, N_v:N_v + N_a, :])
        text_out = self.text_head(x[:, N_v + N_a:, :])

        if return_attn:
            return vision_out, audio_out, text_out, all_attn
        return vision_out, audio_out, text_out

    def _build_attn_mask(self, N_v, N_a, N_t, device):
        """Additive attention mask for the [vision | audio | text] layout.

        Vision <-> Audio: fully bidirectional.
        Text -> Vision/Audio: allowed.
        Vision/Audio -> Text: blocked (-inf).
        Text internal: causal.
        """
        N = N_v + N_a + N_t
        mask = torch.zeros(N, N, device=device)

        # Causal mask within the text segment.
        text_start = N_v + N_a
        text_mask = torch.triu(
            torch.ones(N_t, N_t, device=device) * float('-inf'), diagonal=1
        )
        mask[text_start:, text_start:] = text_mask

        # Vision rows may not attend to any text column...
        mask[:N_v, text_start:] = float('-inf')
        # ...and neither may audio rows.
        mask[N_v:text_start, text_start:] = float('-inf')

        return mask

    def count_params(self):
        """Return {"total": ..., "trainable": ...} parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {"total": total, "trainable": trainable}
191
+
192
+
193
if __name__ == "__main__":
    # Smoke test: build the default model and run one forward pass on random inputs.
    model = TriModalModel(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params['total']:,} ({params['total']/1e6:.1f}M)")

    batch_size = 4
    num_vision, num_audio, num_text = 80, 200, 128
    patch_dim = CONFIG["patch_size"] ** 2

    vision_in = torch.randn(batch_size, num_vision, patch_dim)
    audio_in = torch.randn(batch_size, num_audio, CONFIG["audio_feat_dim"])
    text_in = torch.randint(0, CONFIG["vocab_size"], (batch_size, num_text))

    v_out, a_out, t_out = model(vision_in, audio_in, text_in)
    print(f"Vision out: {v_out.shape}")  # expect (4, 80, 256)
    print(f"Audio out: {a_out.shape}")  # expect (4, 200, 768)
    print(f"Text out: {t_out.shape}")  # expect (4, 128, 10000)
    print("Forward pass OK")