"""
Phase 2-A Toy PoC: 3-way Modality-Specific FFN (Vision + Audio + Text)
Shared Attention + ffn_vision / ffn_audio / ffn_text
"""
import torch
import torch.nn as nn
CONFIG = {
"d_model": 256,
"n_heads": 4,
"ffn_dim": 512,
"n_layers": 6,
"vocab_size": 10000,
"patch_size": 16,
"max_seq_len": 512,
"dropout": 0.1,
"audio_feat_dim": 768,
}
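# Note: vision patches are flattened single-channel 16x16 values, so
# patch_dim = patch_size**2 = 256; RGB patches would need patch_size**2 * 3.
# audio_feat_dim=768 matches the hidden size of common speech encoders
# (e.g., wav2vec 2.0 base), but any precomputed feature of that width works.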
class FeedForward(nn.Module):
def __init__(self, d_model, ffn_dim, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, ffn_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(ffn_dim, d_model),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class TriModalTransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, ffn_dim, dropout=0.1):
super().__init__()
self.attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ffn_vision = FeedForward(d_model, ffn_dim, dropout)
self.ffn_audio = FeedForward(d_model, ffn_dim, dropout)
self.ffn_text = FeedForward(d_model, ffn_dim, dropout)
def forward(self, x, attn_mask, v_idx, a_idx, t_idx):
        # Shared attention across all modalities (pre-norm residual branch)
        residual = x
        x_norm = self.norm1(x)
        x_attn, attn_weights = self.attn(
            x_norm, x_norm, x_norm, attn_mask=attn_mask,
            # need_weights=True returns per-head attention maps but disables
            # PyTorch's fused fast path; acceptable for a toy PoC.
            need_weights=True, average_attn_weights=False,
        )
x = residual + x_attn
        # 3-way modality-specific FFN: each modality's tokens go through their
        # own FFN. Concatenating in [vision | audio | text] order is valid only
        # because v_idx / a_idx / t_idx are contiguous ranges in that order.
        residual = x
        x_norm = self.norm2(x)
        v_out = self.ffn_vision(x_norm[:, v_idx, :])
        a_out = self.ffn_audio(x_norm[:, a_idx, :])
        t_out = self.ffn_text(x_norm[:, t_idx, :])
        out = torch.cat([v_out, a_out, t_out], dim=1)
        x = residual + out
return x, attn_weights
class TriModalModel(nn.Module):
def __init__(self, cfg=None):
super().__init__()
cfg = cfg or CONFIG
self.cfg = cfg
d = cfg["d_model"]
patch_dim = cfg["patch_size"] ** 2
# Embeddings
self.vision_embed = nn.Linear(patch_dim, d)
self.audio_proj = nn.Linear(cfg["audio_feat_dim"], d)
self.text_embed = nn.Embedding(cfg["vocab_size"], d)
self.vision_norm = nn.LayerNorm(d)
self.audio_norm = nn.LayerNorm(d)
self.text_norm = nn.LayerNorm(d)
self.pos_embed = nn.Embedding(cfg["max_seq_len"], d)
# Transformer
self.blocks = nn.ModuleList([
TriModalTransformerBlock(d, cfg["n_heads"], cfg["ffn_dim"], cfg["dropout"])
for _ in range(cfg["n_layers"])
])
self.final_norm = nn.LayerNorm(d)
# Heads
self.vision_head = nn.Linear(d, patch_dim)
self.audio_head = nn.Linear(d, cfg["audio_feat_dim"])
self.text_head = nn.Linear(d, cfg["vocab_size"])
self._init_weights()
def _init_weights(self):
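        # nn.MultiheadAttention stores in_proj_weight as a raw Parameter rather
        # than a Linear module, so it keeps PyTorch's default Xavier init here;
        # only out_proj (a Linear subclass) is re-initialized by this loop.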
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, std=0.02)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, vision_patches, audio_features, text_tokens, return_attn=False):
"""
vision_patches: (B, N_v, patch_dim)
audio_features: (B, N_a, 768)
text_tokens: (B, N_t)
"""
B = text_tokens.size(0)
N_v = vision_patches.size(1)
N_a = audio_features.size(1)
N_t = text_tokens.size(1)
        N = N_v + N_a + N_t
        assert N <= self.cfg["max_seq_len"], \
            f"sequence length {N} exceeds max_seq_len {self.cfg['max_seq_len']}"
        device = text_tokens.device
# Embed
v_emb = self.vision_norm(self.vision_embed(vision_patches))
a_emb = self.audio_norm(self.audio_proj(audio_features))
t_emb = self.text_norm(self.text_embed(text_tokens))
# Concat: [vision | audio | text]
x = torch.cat([v_emb, a_emb, t_emb], dim=1)
pos = torch.arange(N, device=device)
x = x + self.pos_embed(pos)
        # Attention mask and contiguous per-modality index ranges
        attn_mask = self._build_attn_mask(N_v, N_a, N_t, device)
        v_idx = torch.arange(0, N_v, device=device)
        a_idx = torch.arange(N_v, N_v + N_a, device=device)
        t_idx = torch.arange(N_v + N_a, N, device=device)
# Transformer
all_attn = []
for block in self.blocks:
x, attn_w = block(x, attn_mask, v_idx, a_idx, t_idx)
if return_attn:
all_attn.append(attn_w.detach())
x = self.final_norm(x)
# Heads
vision_out = self.vision_head(x[:, :N_v, :])
audio_out = self.audio_head(x[:, N_v:N_v + N_a, :])
text_out = self.text_head(x[:, N_v + N_a:, :])
if return_attn:
return vision_out, audio_out, text_out, all_attn
return vision_out, audio_out, text_out
def _build_attn_mask(self, N_v, N_a, N_t, device):
"""
[Vision | Audio | Text] ordering.
Vision ↔ Audio: Bidirectional (mutual)
Text → Vision/Audio: allowed
Vision/Audio → Text: blocked
Text internal: Causal
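
        Example (N_v=1, N_a=1, N_t=2; rows = queries, cols = keys):
                  V     A     T0    T1
            V     0     0    -inf  -inf
            A     0     0    -inf  -inf
            T0    0     0     0    -inf
            T1    0     0     0     0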
"""
N = N_v + N_a + N_t
mask = torch.zeros(N, N, device=device)
# Text causal mask
text_start = N_v + N_a
text_mask = torch.triu(
torch.ones(N_t, N_t, device=device) * float('-inf'), diagonal=1
)
mask[text_start:, text_start:] = text_mask
# Vision → Text: blocked
mask[:N_v, text_start:] = float('-inf')
# Audio → Text: blocked
mask[N_v:text_start, text_start:] = float('-inf')
return mask
def count_params(self):
total = sum(p.numel() for p in self.parameters())
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {"total": total, "trainable": trainable}
if __name__ == "__main__":
model = TriModalModel(CONFIG)
params = model.count_params()
print(f"Parameters: {params['total']:,} ({params['total']/1e6:.1f}M)")
B = 4
N_v, N_a, N_t = 80, 200, 128
patch_dim = CONFIG["patch_size"] ** 2
v = torch.randn(B, N_v, patch_dim)
a = torch.randn(B, N_a, CONFIG["audio_feat_dim"])
t = torch.randint(0, CONFIG["vocab_size"], (B, N_t))
v_out, a_out, t_out = model(v, a, t)
print(f"Vision out: {v_out.shape}") # (4, 80, 256)
print(f"Audio out: {a_out.shape}") # (4, 200, 768)
print(f"Text out: {t_out.shape}") # (4, 128, 10000)
print("Forward pass OK")