| """
|
| Phase 2-A Toy PoC: 3-way Modality-Specific FFN (Vision + Audio + Text)
|
| Shared Attention + ffn_vision / ffn_audio / ffn_text
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
|
|
# Hyper-parameters for the toy tri-modal transformer.
CONFIG = {
    "d_model": 256,         # shared embedding / hidden width for all modalities
    "n_heads": 4,           # attention heads per transformer block
    "ffn_dim": 512,         # hidden width of each modality-specific FFN
    "n_layers": 6,          # number of TriModalTransformerBlock layers
    "vocab_size": 10000,    # text vocabulary size
    "patch_size": 16,       # vision patch side; flattened patch dim = patch_size ** 2
    "max_seq_len": 512,     # capacity of the learned positional embedding
    "dropout": 0.1,         # dropout used in attention and FFNs
    "audio_feat_dim": 768,  # dimensionality of incoming audio frame features
}
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Expands ``d_model`` to ``ffn_dim`` and projects back, applied
    independently at every sequence position.
    """

    def __init__(self, d_model, ffn_dim, dropout=0.1):
        super().__init__()
        expand = nn.Linear(d_model, ffn_dim)
        contract = nn.Linear(ffn_dim, d_model)
        # Kept as a single ``net`` Sequential so state-dict keys match
        # the conventional (index-based) layout.
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``; shape is preserved."""
        return self.net(x)
|
|
|
|
|
class TriModalTransformerBlock(nn.Module):
    """Pre-norm transformer block: shared self-attention + per-modality FFNs.

    Tokens from all three modalities attend jointly (subject to
    ``attn_mask``); each token is then routed through the FFN of its own
    modality (vision / audio / text).
    """

    def __init__(self, d_model, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn_vision = FeedForward(d_model, ffn_dim, dropout)
        self.ffn_audio = FeedForward(d_model, ffn_dim, dropout)
        self.ffn_text = FeedForward(d_model, ffn_dim, dropout)

    def forward(self, x, attn_mask, v_idx, a_idx, t_idx):
        """Apply shared attention then modality-routed FFNs.

        Args:
            x: (B, N, d_model) concatenated token sequence.
            attn_mask: (N, N) additive float mask for self-attention.
            v_idx, a_idx, t_idx: 1-D index tensors selecting each modality's
                positions; together they are assumed to partition [0, N).

        Returns:
            (x, attn_weights) where attn_weights is (B, n_heads, N, N)
            (per-head, since average_attn_weights=False).
        """
        # Shared self-attention sub-layer (pre-norm + residual).
        residual = x
        x_norm = self.norm1(x)
        x_attn, attn_weights = self.attn(
            x_norm, x_norm, x_norm, attn_mask=attn_mask,
            need_weights=True, average_attn_weights=False,
        )
        x = residual + x_attn

        # Modality-specific FFN sub-layer (pre-norm + residual).
        residual = x
        x_norm = self.norm2(x)
        # Scatter each modality's FFN output back to its original positions.
        # Unlike concatenating the three chunks (which silently required the
        # index tensors to be contiguous, ordered slices), this stays correct
        # for any partition of the sequence.
        out = torch.empty_like(x_norm)
        out[:, v_idx, :] = self.ffn_vision(x_norm[:, v_idx, :])
        out[:, a_idx, :] = self.ffn_audio(x_norm[:, a_idx, :])
        out[:, t_idx, :] = self.ffn_text(x_norm[:, t_idx, :])
        x = residual + out

        return x, attn_weights
|
|
|
|
|
class TriModalModel(nn.Module):
    """Toy tri-modal transformer with modality-specific FFNs.

    Each modality is embedded into a shared ``d_model`` space, concatenated
    as [vision | audio | text], and processed by a stack of
    ``TriModalTransformerBlock`` layers that share attention but route
    tokens through per-modality FFNs. Per-modality heads project back to
    each modality's output space.
    """

    def __init__(self, cfg=None):
        super().__init__()
        cfg = cfg or CONFIG  # fall back to the module-level defaults
        self.cfg = cfg
        d = cfg["d_model"]
        # Flattened single-channel patch dimension.
        patch_dim = cfg["patch_size"] ** 2

        # Per-modality input embeddings, each followed by its own LayerNorm.
        self.vision_embed = nn.Linear(patch_dim, d)
        self.audio_proj = nn.Linear(cfg["audio_feat_dim"], d)
        self.text_embed = nn.Embedding(cfg["vocab_size"], d)
        self.vision_norm = nn.LayerNorm(d)
        self.audio_norm = nn.LayerNorm(d)
        self.text_norm = nn.LayerNorm(d)
        # Learned absolute positions over the concatenated sequence.
        self.pos_embed = nn.Embedding(cfg["max_seq_len"], d)

        self.blocks = nn.ModuleList([
            TriModalTransformerBlock(d, cfg["n_heads"], cfg["ffn_dim"], cfg["dropout"])
            for _ in range(cfg["n_layers"])
        ])
        self.final_norm = nn.LayerNorm(d)

        # Per-modality reconstruction / prediction heads.
        self.vision_head = nn.Linear(d, patch_dim)
        self.audio_head = nn.Linear(d, cfg["audio_feat_dim"])
        self.text_head = nn.Linear(d, cfg["vocab_size"])

        self._init_weights()

    def _init_weights(self):
        """GPT-2 style init: N(0, 0.02) weights, zero biases, unit LayerNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, vision_patches, audio_features, text_tokens, return_attn=False):
        """Run one forward pass over the three modalities.

        Args:
            vision_patches: (B, N_v, patch_size**2) flattened patches.
            audio_features: (B, N_a, audio_feat_dim) audio frame features.
            text_tokens: (B, N_t) integer token ids.
            return_attn: if True, also return per-layer attention weights
                (detached).

        Returns:
            (vision_out, audio_out, text_out), plus the list of attention
            tensors when ``return_attn`` is set.

        Raises:
            ValueError: if the combined sequence exceeds cfg["max_seq_len"].
        """
        N_v = vision_patches.size(1)
        N_a = audio_features.size(1)
        N_t = text_tokens.size(1)
        N = N_v + N_a + N_t
        device = text_tokens.device

        # Fail fast with a clear message instead of an opaque embedding
        # IndexError when the concatenated sequence is too long.
        if N > self.cfg["max_seq_len"]:
            raise ValueError(
                f"combined sequence length {N} (vision {N_v} + audio {N_a} "
                f"+ text {N_t}) exceeds max_seq_len {self.cfg['max_seq_len']}"
            )

        v_emb = self.vision_norm(self.vision_embed(vision_patches))
        a_emb = self.audio_norm(self.audio_proj(audio_features))
        t_emb = self.text_norm(self.text_embed(text_tokens))

        # [vision | audio | text] concatenation + shared positional embedding.
        x = torch.cat([v_emb, a_emb, t_emb], dim=1)
        pos = torch.arange(N, device=device)
        x = x + self.pos_embed(pos)

        # Mask and per-modality position indices for FFN routing.
        attn_mask = self._build_attn_mask(N_v, N_a, N_t, device)
        v_idx = torch.arange(0, N_v, device=device)
        a_idx = torch.arange(N_v, N_v + N_a, device=device)
        t_idx = torch.arange(N_v + N_a, N, device=device)

        all_attn = []
        for block in self.blocks:
            x, attn_w = block(x, attn_mask, v_idx, a_idx, t_idx)
            if return_attn:
                all_attn.append(attn_w.detach())

        x = self.final_norm(x)

        # Slice the sequence back apart and project each modality.
        vision_out = self.vision_head(x[:, :N_v, :])
        audio_out = self.audio_head(x[:, N_v:N_v + N_a, :])
        text_out = self.text_head(x[:, N_v + N_a:, :])

        if return_attn:
            return vision_out, audio_out, text_out, all_attn
        return vision_out, audio_out, text_out

    def _build_attn_mask(self, N_v, N_a, N_t, device):
        """Build an additive float attention mask for [Vision | Audio | Text].

        Vision ↔ Audio: bidirectional (mutual).
        Text → Vision/Audio: allowed.
        Vision/Audio → Text: blocked.
        Text internal: causal.
        """
        N = N_v + N_a + N_t
        mask = torch.zeros(N, N, device=device)

        # Causal mask within the text segment (strict upper triangle blocked).
        text_start = N_v + N_a
        text_mask = torch.triu(
            torch.ones(N_t, N_t, device=device) * float('-inf'), diagonal=1
        )
        mask[text_start:, text_start:] = text_mask

        # Vision rows may not attend to text columns.
        mask[:N_v, text_start:] = float('-inf')
        # Audio rows may not attend to text columns.
        mask[N_v:text_start, text_start:] = float('-inf')

        return mask

    def count_params(self):
        """Return total and trainable parameter counts as a dict."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {"total": total, "trainable": trainable}
|
|
|
|
|
| if __name__ == "__main__":
|
| model = TriModalModel(CONFIG)
|
| params = model.count_params()
|
| print(f"Parameters: {params['total']:,} ({params['total']/1e6:.1f}M)")
|
|
|
| B = 4
|
| N_v, N_a, N_t = 80, 200, 128
|
| patch_dim = CONFIG["patch_size"] ** 2
|
| v = torch.randn(B, N_v, patch_dim)
|
| a = torch.randn(B, N_a, CONFIG["audio_feat_dim"])
|
| t = torch.randint(0, CONFIG["vocab_size"], (B, N_t))
|
|
|
| v_out, a_out, t_out = model(v, a, t)
|
| print(f"Vision out: {v_out.shape}")
|
| print(f"Audio out: {a_out.shape}")
|
| print(f"Text out: {t_out.shape}")
|
| print("Forward pass OK")
|
|
|