| """ |
| MINDI 1.5 Vision-Coder β Vision-Language Fusion Layer |
| |
| Prepends projected visual tokens (256 Γ 3584) to text token embeddings |
| and extends the attention mask accordingly. Uses Linear + LayerNorm |
| for the visual projection gate. |
| """ |

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn


class VisionLanguageFusion(nn.Module):
    """
    Fuses visual and text embeddings by prepending visual tokens.

    Pipeline:
        1. visual_tokens (batch, 256, 3584) → Linear → LayerNorm
        2. Prepend to text_embeds (batch, seq_len, 3584)
        3. Extend attention_mask to cover the extra 256 visual positions

    All trainable parameters live in the gate projection, the LayerNorm,
    and a learned scalar text gate.
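
    Example (illustrative; 256 visual tokens prepended to a 16-token
    sequence at the default hidden size):

        >>> fusion = VisionLanguageFusion()
        >>> text = torch.randn(2, 16, 3584)
        >>> vis = torch.randn(2, 256, 3584)
        >>> mask = torch.ones(2, 16, dtype=torch.long)
        >>> fused, fused_mask = fusion(text, vis, mask)
        >>> fused.shape, fused_mask.shape
        (torch.Size([2, 272, 3584]), torch.Size([2, 272]))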
| """ |
|
|
    def __init__(
        self,
        hidden_size: int = 3584,
        num_visual_tokens: int = 256,
    ) -> None:
        """
        Initialize the fusion layer.

        Args:
            hidden_size: Dimension of both visual and text embeddings (must match).
            num_visual_tokens: Number of visual tokens prepended (default 256).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_visual_tokens = num_visual_tokens

        # Shared projection + normalization applied to the visual tokens.
        self.gate_proj = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Learned scalar gate for the text-only path; sigmoid(0) = 0.5 at
        # initialization, so the transform starts half-blended.
        self.text_gate = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        text_embeds: torch.Tensor,
        visual_tokens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Fuse visual tokens into text embeddings.

        Args:
            text_embeds: Text token embeddings (batch, seq_len, hidden_size).
            visual_tokens: Projected visual tokens (batch, 256, hidden_size), or None
                for text-only inputs.
            attention_mask: Text attention mask (batch, seq_len), or None.

        Returns:
            fused_embeds: (batch, 256 + seq_len, hidden_size) if visual tokens are
                given; otherwise (batch, seq_len, hidden_size), with the gated
                Linear + LayerNorm transform blended into the text embeddings.
            fused_mask: Extended attention mask, or None if the input mask was None.
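
        Example (illustrative, text-only path; the shape is preserved):

            >>> fusion = VisionLanguageFusion(hidden_size=8, num_visual_tokens=4)
            >>> out, mask = fusion(torch.randn(1, 5, 8), None, None)
            >>> out.shape
            torch.Size([1, 5, 8])
            >>> mask is None
            True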
| """ |
| |
| |
| |
| |
| if visual_tokens is None: |
| alpha = torch.sigmoid(self.text_gate) |
| transformed = self.layer_norm(self.gate_proj(text_embeds)) |
| fused_embeds = text_embeds + alpha * (transformed - text_embeds) |
| return fused_embeds, attention_mask |

        batch_size = text_embeds.shape[0]
        v_batch = visual_tokens.shape[0]

        # Broadcast a single image's tokens across the whole batch.
        if v_batch == 1 and batch_size > 1:
            visual_tokens = visual_tokens.expand(batch_size, -1, -1)

        # Project and normalize the visual tokens.
        gated_visual = self.gate_proj(visual_tokens)
        gated_visual = self.layer_norm(gated_visual)

        # Prepend the visual tokens to the text sequence.
        fused_embeds = torch.cat([gated_visual, text_embeds], dim=1)

        # Extend the mask so every visual position is attended.
        fused_mask = self._extend_attention_mask(attention_mask, batch_size, text_embeds.device)

        return fused_embeds, fused_mask

    def _extend_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        batch_size: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """
        Extend attention mask to include visual token positions (all attended).

        Args:
            attention_mask: Original text mask (batch, seq_len) or None.
            batch_size: Current batch size.
            device: Target device.

        Returns:
            Extended mask (batch, num_visual_tokens + seq_len) or None.
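
        Example (illustrative; four visual positions prepended to a
        three-token mask):

            >>> fusion = VisionLanguageFusion(hidden_size=8, num_visual_tokens=4)
            >>> mask = torch.ones(2, 3, dtype=torch.long)
            >>> fusion._extend_attention_mask(mask, 2, mask.device).shape
            torch.Size([2, 7])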
| """ |
| if attention_mask is None: |
| return None |
|
|
| |
| visual_mask = torch.ones( |
| batch_size, |
| self.num_visual_tokens, |
| dtype=attention_mask.dtype, |
| device=device, |
| ) |
| return torch.cat([visual_mask, attention_mask], dim=1) |

    def get_trainable_params(self) -> dict:
        """
        Count trainable parameters in the fusion layer.

        Returns:
            Dictionary with 'trainable', 'total', and 'trainable_pct'.
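
        Example (illustrative; at hidden_size=8 the layer holds 8*8 + 8
        parameters in gate_proj, 8 + 8 in layer_norm, plus the scalar
        text gate, i.e. 89 in total):

            >>> VisionLanguageFusion(hidden_size=8).get_trainable_params()
            {'trainable': 89, 'total': 89, 'trainable_pct': 100.0}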
| """ |
| trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| total = sum(p.numel() for p in self.parameters()) |
| pct = 100.0 * trainable / total if total > 0 else 0.0 |
| return { |
| "trainable": trainable, |
| "total": total, |
| "trainable_pct": round(pct, 4), |
| } |

    def extra_repr(self) -> str:
        return (
            f"hidden_size={self.hidden_size}, "
            f"num_visual_tokens={self.num_visual_tokens}"
        )
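

# A minimal sketch of how the fused outputs might feed a downstream decoder
# that accepts precomputed embeddings. The `language_model` name and its
# `inputs_embeds`/`attention_mask` keyword interface are assumptions for
# illustration, not part of this module:
#
#     fused_embeds, fused_mask = fusion(text_embeds, visual_tokens, attention_mask)
#     outputs = language_model(inputs_embeds=fused_embeds, attention_mask=fused_mask)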


if __name__ == "__main__":
    print("=" * 60)
    print(" MINDI 1.5 - Fusion Layer Test")
    print("=" * 60)
    print()

    BATCH = 2
    SEQ_LEN = 128
    HIDDEN = 4096
    N_VIS = 256

    fusion = VisionLanguageFusion(hidden_size=HIDDEN, num_visual_tokens=N_VIS)
    print(f" Fusion layer:\n {fusion}\n")

    print(" Test 1: Vision + Text fusion")
    text_embeds = torch.randn(BATCH, SEQ_LEN, HIDDEN)
    visual_tokens = torch.randn(BATCH, N_VIS, HIDDEN)
    attention_mask = torch.ones(BATCH, SEQ_LEN, dtype=torch.long)

    fused_embeds, fused_mask = fusion(text_embeds, visual_tokens, attention_mask)

    expected_seq = N_VIS + SEQ_LEN
    assert fused_embeds.shape == (BATCH, expected_seq, HIDDEN), \
        f"Expected ({BATCH}, {expected_seq}, {HIDDEN}), got {fused_embeds.shape}"
    assert fused_mask is not None and fused_mask.shape == (BATCH, expected_seq), \
        f"Expected mask ({BATCH}, {expected_seq}), got {fused_mask.shape}"
    print(f" fused_embeds: {fused_embeds.shape} ✓")
    print(f" fused_mask: {fused_mask.shape} ✓")

    print("\n Test 2: Text-only (no vision)")
    text_only, mask_only = fusion(text_embeds, None, attention_mask)
    assert text_only.shape == (BATCH, SEQ_LEN, HIDDEN)
    assert mask_only is not None and mask_only.shape == (BATCH, SEQ_LEN)
    print(f" text_only: {text_only.shape} ✓")
    print(f" mask_only: {mask_only.shape} ✓")

    print("\n Test 3: Vision fusion without attention mask")
    fused_no_mask, none_mask = fusion(text_embeds, visual_tokens, None)
    assert fused_no_mask.shape == (BATCH, expected_seq, HIDDEN)
    assert none_mask is None
    print(f" fused_embeds: {fused_no_mask.shape} ✓")
    print(" fused_mask: None ✓")

    print("\n Test 4: Single-image broadcast to batch")
    single_visual = torch.randn(1, N_VIS, HIDDEN)
    fused_bc, mask_bc = fusion(text_embeds, single_visual, attention_mask)
    assert fused_bc.shape == (BATCH, expected_seq, HIDDEN)
    print(f" fused_embeds: {fused_bc.shape} ✓ (broadcast 1 → {BATCH})")

    print("\n Test 5: Parameter counts")
    info = fusion.get_trainable_params()

    # gate_proj (weight + bias) + layer_norm (weight + bias) + scalar text_gate.
    expected_params = HIDDEN * HIDDEN + HIDDEN + HIDDEN + HIDDEN + 1
    assert info["trainable"] == expected_params, \
        f"Expected {expected_params}, got {info['trainable']}"
    print(f" Trainable: {info['trainable']:,}")
    print(f" Total: {info['total']:,}")
    print(f" Pct: {info['trainable_pct']}%")

    print("\n Test 6: Gradient flow through fusion")
    fusion.zero_grad()
    fused_embeds, _ = fusion(text_embeds, visual_tokens, attention_mask)
    loss = fused_embeds.sum()
    loss.backward()
    assert fusion.gate_proj.weight.grad is not None, "No gradient on gate_proj!"
    assert fusion.layer_norm.weight.grad is not None, "No gradient on layer_norm!"
    print(" gate_proj gradient: ✓")
    print(" layer_norm gradient: ✓")

    print("\n ✓ All fusion layer tests passed!")
    print("=" * 60)