Faaz
Fix Phase 2: fusion layer processes text-only via learnable residual gate for gradient flow
4e9835e | """ | |
| MINDI 1.5 Vision-Coder — Vision-Language Fusion Layer | |
| Prepends projected visual tokens (256 × 4096) to text token embeddings | |
| and extends the attention mask accordingly. Uses Linear + LayerNorm | |
| for the visual projection gate. | |
| """ | |
| from __future__ import annotations | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
class VisionLanguageFusion(nn.Module):
    """
    Fuse visual and text embeddings by prepending projected visual tokens.

    Vision + text pipeline:
        1. visual_tokens (batch, n_vis, hidden_size) -> Linear -> LayerNorm
        2. Prepend the result to text_embeds (batch, seq_len, hidden_size)
        3. Extend attention_mask with n_vis always-attended positions

    Text-only pipeline (no visual tokens supplied):
        out = text + tanh(text_gate) * (LayerNorm(Linear(text)) - text)

    ``text_gate`` starts at zero, so a freshly constructed layer returns the
    text embeddings unchanged while still keeping every fusion parameter in
    the autograd graph.

    Trainable parameters: the gate projection, the LayerNorm and the scalar
    ``text_gate``.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        num_visual_tokens: int = 256,
    ) -> None:
        """
        Initialize the fusion layer.

        Args:
            hidden_size: Dimension of both visual and text embeddings (must match).
            num_visual_tokens: Nominal number of visual tokens prepended
                (default 256). Used as the fallback mask width when the
                actual token count is not supplied.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_visual_tokens = num_visual_tokens

        # Gate projection: Linear + LayerNorm to align visual features.
        self.gate_proj = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Text-only residual gate (learnable scalar). It is squashed with
        # tanh in forward(), so the zero init gives a blend weight of exactly
        # 0 and the text path is the identity until the gate is trained.
        # NOTE: a sigmoid here would give 0.5 at init, i.e. half of every text
        # embedding would be replaced by the randomly initialised projection.
        # NOTE(review): checkpoints trained with a sigmoid-gated version map
        # the stored text_gate value to a different blend weight.
        self.text_gate = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        text_embeds: torch.Tensor,
        visual_tokens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Fuse visual tokens into text embeddings.

        Args:
            text_embeds: Text token embeddings (batch, seq_len, hidden_size).
            visual_tokens: Projected visual tokens (batch or 1, n_vis,
                hidden_size), or None for text-only inputs. A leading
                dimension of 1 is broadcast to the text batch size.
            attention_mask: Text attention mask (batch, seq_len), or None.

        Returns:
            fused_embeds: (batch, n_vis + seq_len, hidden_size) when visual
                tokens are given; otherwise (batch, seq_len, hidden_size),
                equal to text_embeds while text_gate is zero.
            fused_mask: Mask extended by n_vis leading ones when visual tokens
                are given; the input mask itself for text-only inputs; None
                if the input mask was None.
        """
        # Text-only path: gated residual through the fusion parameters so they
        # stay in the autograd graph even without images. With text_gate == 0
        # the blend weight is tanh(0) == 0 and the output equals text_embeds;
        # the gate itself still receives a gradient, and gate_proj/layer_norm
        # start receiving non-zero gradients once the gate has moved.
        if visual_tokens is None:
            alpha = torch.tanh(self.text_gate)
            transformed = self.layer_norm(self.gate_proj(text_embeds))
            fused_embeds = text_embeds + alpha * (transformed - text_embeds)
            return fused_embeds, attention_mask

        batch_size = text_embeds.shape[0]

        # Handle batch size mismatch (single image broadcast to batch).
        if visual_tokens.shape[0] == 1 and batch_size > 1:
            visual_tokens = visual_tokens.expand(batch_size, -1, -1)

        # Gate projection + LayerNorm -> (batch, n_vis, hidden_size).
        gated_visual = self.layer_norm(self.gate_proj(visual_tokens))

        # Prepend visual tokens to text embeddings.
        fused_embeds = torch.cat([gated_visual, text_embeds], dim=1)

        # Extend the mask by the number of visual tokens actually prepended,
        # so mask and embeddings always agree on the sequence length even if
        # the caller passes a token count other than self.num_visual_tokens.
        fused_mask = self._extend_attention_mask(
            attention_mask,
            batch_size,
            text_embeds.device,
            num_visual_tokens=gated_visual.shape[1],
        )
        return fused_embeds, fused_mask

    def _extend_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        batch_size: int,
        device: torch.device,
        num_visual_tokens: Optional[int] = None,
    ) -> Optional[torch.Tensor]:
        """
        Extend attention mask to include visual token positions (all attended).

        Args:
            attention_mask: Original text mask (batch, seq_len) or None.
            batch_size: Current batch size.
            device: Device on which the visual part of the mask is created.
            num_visual_tokens: Number of visual positions to add. Defaults to
                self.num_visual_tokens when omitted.

        Returns:
            Extended mask (batch, num_visual_tokens + seq_len), with the same
            dtype as the input mask, or None if attention_mask is None.
        """
        if attention_mask is None:
            return None

        if num_visual_tokens is None:
            num_visual_tokens = self.num_visual_tokens

        # Visual tokens are always fully attended.
        visual_mask = torch.ones(
            batch_size,
            num_visual_tokens,
            dtype=attention_mask.dtype,
            device=device,
        )
        return torch.cat([visual_mask, attention_mask], dim=1)

    def get_trainable_params(self) -> dict:
        """
        Count trainable parameters in the fusion layer.

        Returns:
            Dictionary with 'trainable' (elements of parameters that require
            grad), 'total' (elements of all parameters) and 'trainable_pct'
            (percentage rounded to 4 decimals; 0.0 when there are no
            parameters).
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        pct = 100.0 * trainable / total if total > 0 else 0.0
        return {
            "trainable": trainable,
            "total": total,
            "trainable_pct": round(pct, 4),
        }

    def extra_repr(self) -> str:
        """Return the constructor arguments shown in the module's repr."""
        return (
            f"hidden_size={self.hidden_size}, "
            f"num_visual_tokens={self.num_visual_tokens}"
        )
| # ── Test block ──────────────────────────────────────────────────────── | |
# Self-test: run this file directly to exercise the fusion layer on random
# CPU tensors (shape checks, mask handling, parameter count, gradient flow).
# NOTE(review): the "β" glyphs inside the printed strings look like
# mis-encoded symbols (dash / arrow / check mark). They are runtime text and
# are left untouched here; confirm against the original file before changing.
if __name__ == "__main__":
    print("=" * 60)
    print(" MINDI 1.5 β Fusion Layer Test")
    print("=" * 60)
    print()

    BATCH = 2
    SEQ_LEN = 128
    HIDDEN = 4096
    N_VIS = 256

    fusion = VisionLanguageFusion(hidden_size=HIDDEN, num_visual_tokens=N_VIS)
    print(f" Fusion layer:\n {fusion}\n")

    # -- Test 1: Vision + Text fusion ---------------------------------
    print(" Test 1: Vision + Text fusion")
    text_embeds = torch.randn(BATCH, SEQ_LEN, HIDDEN)
    visual_tokens = torch.randn(BATCH, N_VIS, HIDDEN)
    attention_mask = torch.ones(BATCH, SEQ_LEN, dtype=torch.long)
    fused_embeds, fused_mask = fusion(text_embeds, visual_tokens, attention_mask)
    expected_seq = N_VIS + SEQ_LEN  # 256 + 128 = 384
    assert fused_embeds.shape == (BATCH, expected_seq, HIDDEN), \
        f"Expected ({BATCH}, {expected_seq}, {HIDDEN}), got {fused_embeds.shape}"
    # Checked separately so that a None mask fails with a clean assertion
    # instead of an AttributeError while formatting the message below.
    assert fused_mask is not None
    assert fused_mask.shape == (BATCH, expected_seq), \
        f"Expected mask ({BATCH}, {expected_seq}), got {fused_mask.shape}"
    print(f" fused_embeds: {fused_embeds.shape} β")
    print(f" fused_mask: {fused_mask.shape} β")

    # -- Test 2: Text-only (no vision) --------------------------------
    print("\n Test 2: Text-only (no vision)")
    text_only, mask_only = fusion(text_embeds, None, attention_mask)
    assert text_only.shape == (BATCH, SEQ_LEN, HIDDEN)
    assert mask_only is not None and mask_only.shape == (BATCH, SEQ_LEN)
    print(f" text_only: {text_only.shape} β")
    print(f" mask_only: {mask_only.shape} β")

    # -- Test 3: No attention mask ------------------------------------
    print("\n Test 3: Vision fusion without attention mask")
    fused_no_mask, none_mask = fusion(text_embeds, visual_tokens, None)
    assert fused_no_mask.shape == (BATCH, expected_seq, HIDDEN)
    assert none_mask is None
    print(f" fused_embeds: {fused_no_mask.shape} β")
    print(f" fused_mask: None β")

    # -- Test 4: Single-image broadcast -------------------------------
    print("\n Test 4: Single-image broadcast to batch")
    single_visual = torch.randn(1, N_VIS, HIDDEN)
    fused_bc, mask_bc = fusion(text_embeds, single_visual, attention_mask)
    assert fused_bc.shape == (BATCH, expected_seq, HIDDEN)
    # The mask must be extended for the broadcast case as well.
    assert mask_bc is not None and mask_bc.shape == (BATCH, expected_seq)
    print(f" fused_embeds: {fused_bc.shape} β (broadcast 1 β {BATCH})")

    # -- Test 5: Trainable params -------------------------------------
    print("\n Test 5: Parameter counts")
    info = fusion.get_trainable_params()
    # gate_proj:  4096*4096 + 4096 = 16,781,312
    # layer_norm: 4096 + 4096     =      8,192
    # text_gate:  1 (scalar gate) =          1
    expected_params = (
        HIDDEN * HIDDEN + HIDDEN  # Linear weight + bias
        + 2 * HIDDEN              # LayerNorm weight + bias
        + 1                       # text_gate
    )
    assert info["trainable"] == expected_params, \
        f"Expected {expected_params}, got {info['trainable']}"
    print(f" Trainable: {info['trainable']:,}")
    print(f" Total: {info['total']:,}")
    print(f" Pct: {info['trainable_pct']}%")

    # -- Test 6: Gradient flow ----------------------------------------
    print("\n Test 6: Gradient flow through fusion")
    fusion.zero_grad()
    fused_embeds, _ = fusion(text_embeds, visual_tokens, attention_mask)
    loss = fused_embeds.sum()
    loss.backward()
    assert fusion.gate_proj.weight.grad is not None, "No gradient on gate_proj!"
    assert fusion.layer_norm.weight.grad is not None, "No gradient on layer_norm!"
    print(" gate_proj gradient: β")
    print(" layer_norm gradient: β")

    print("\n β All fusion layer tests passed!")
    print("=" * 60)