"""
MINDI 1.5 Vision-Coder — Vision-Language Fusion Layer

Prepends projected visual tokens (256 × 3584) to text token embeddings
and extends the attention mask accordingly.  Uses Linear + LayerNorm
for the visual projection gate, plus a tanh-gated residual path so
text-only inputs flow through the same parameters.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn

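# Typical wiring into a VLM forward pass (illustrative sketch; the names
# `embed_tokens`, `vision_projector`, and `language_model` are assumed
# placeholders, not part of this module):
#
#   text_embeds   = embed_tokens(input_ids)                    # (B, T, H)
#   visual_tokens = vision_projector(image_features)           # (B, 256, H)
#   fused, mask   = fusion(text_embeds, visual_tokens, attention_mask)
#   logits        = language_model(inputs_embeds=fused, attention_mask=mask)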

class VisionLanguageFusion(nn.Module):
    """
    Fuses visual and text embeddings by prepending visual tokens.

    Pipeline:
        1. visual_tokens (batch, 256, 3584) → Linear → LayerNorm
        2. Prepend to text_embeds (batch, seq_len, 3584)
        3. Extend attention_mask to cover the extra 256 visual positions

    All trainable parameters live in the gate projection, the LayerNorm,
    and the scalar text gate.
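
    Example (a minimal sketch, using the 3584-dim / 256-token defaults):
        >>> fusion = VisionLanguageFusion()
        >>> text = torch.randn(2, 128, 3584)
        >>> vis = torch.randn(2, 256, 3584)
        >>> mask = torch.ones(2, 128, dtype=torch.long)
        >>> fused, fmask = fusion(text, vis, mask)
        >>> fused.shape, fmask.shape
        (torch.Size([2, 384, 3584]), torch.Size([2, 384]))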
    """

    def __init__(
        self,
        hidden_size: int = 3584,
        num_visual_tokens: int = 256,
    ) -> None:
        """
        Initialize the fusion layer.

        Args:
            hidden_size: Dimension of both visual and text embeddings (must match).
            num_visual_tokens: Number of visual tokens prepended (default 256).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_visual_tokens = num_visual_tokens

        # Gate projection: Linear + LayerNorm to align visual features
        self.gate_proj = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Text-only residual gate: a learnable scalar passed through tanh.
        # tanh(0) = 0, so the text path is exactly identity at init and the
        # fusion transform blends in gradually as the gate trains.
        self.text_gate = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        text_embeds: torch.Tensor,
        visual_tokens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Fuse visual tokens into text embeddings.

        Args:
            text_embeds: Text token embeddings (batch, seq_len, hidden_size).
            visual_tokens: Projected visual tokens (batch, 256, hidden_size), or None
                           for text-only inputs.
            attention_mask: Text attention mask (batch, seq_len), or None.

        Returns:
            fused_embeds: (batch, 256 + seq_len, hidden_size) with visual tokens;
                          for text-only inputs, (batch, seq_len, hidden_size) via
                          the gated residual path (identity at init).
            fused_mask: Extended attention mask, or None if the input mask was None.
        """
        # Text-only path: route through the fusion parameters behind a
        # learnable residual gate so they can still train without images.
        # At init text_gate = 0, so tanh(0) = 0 and the output equals
        # text_embeds exactly; the transform blends in as the gate trains.
        if visual_tokens is None:
            alpha = torch.tanh(self.text_gate)
            transformed = self.layer_norm(self.gate_proj(text_embeds))
            fused_embeds = text_embeds + alpha * (transformed - text_embeds)
            return fused_embeds, attention_mask

        batch_size = text_embeds.shape[0]
        v_batch, v_tokens, _ = visual_tokens.shape
        assert v_tokens == self.num_visual_tokens, (
            f"Expected {self.num_visual_tokens} visual tokens, got {v_tokens}; "
            "the extended attention mask assumes a fixed visual prefix."
        )

        # Handle batch size mismatch (single image broadcast to batch)
        if v_batch == 1 and batch_size > 1:
            visual_tokens = visual_tokens.expand(batch_size, -1, -1)

        # Gate projection + LayerNorm
        gated_visual = self.gate_proj(visual_tokens)   # (batch, 256, hidden_size)
        gated_visual = self.layer_norm(gated_visual)    # (batch, 256, hidden_size)

        # Prepend visual tokens to text embeddings
        fused_embeds = torch.cat([gated_visual, text_embeds], dim=1)

        # Extend attention mask
        fused_mask = self._extend_attention_mask(attention_mask, batch_size, text_embeds.device)

        return fused_embeds, fused_mask

    def _extend_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        batch_size: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """
        Extend attention mask to include visual token positions (all attended).

        Args:
            attention_mask: Original text mask (batch, seq_len) or None.
            batch_size: Current batch size.
            device: Target device.

        Returns:
            Extended mask (batch, 256 + seq_len) or None.
        """
        if attention_mask is None:
            return None

        # Visual tokens are always fully attended
        visual_mask = torch.ones(
            batch_size,
            self.num_visual_tokens,
            dtype=attention_mask.dtype,
            device=device,
        )
        return torch.cat([visual_mask, attention_mask], dim=1)

    def get_trainable_params(self) -> dict:
        """
        Count trainable parameters in the fusion layer.

        Returns:
            Dictionary with 'trainable', 'total', and 'trainable_pct'.
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        pct = 100.0 * trainable / total if total > 0 else 0.0
        return {
            "trainable": trainable,
            "total": total,
            "trainable_pct": round(pct, 4),
        }

    def extra_repr(self) -> str:
        return (
            f"hidden_size={self.hidden_size}, "
            f"num_visual_tokens={self.num_visual_tokens}"
        )


# ── Test block ────────────────────────────────────────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print("  MINDI 1.5 β€” Fusion Layer Test")
    print("=" * 60)
    print()

    BATCH = 2
    SEQ_LEN = 128
    HIDDEN = 4096
    N_VIS = 256

    fusion = VisionLanguageFusion(hidden_size=HIDDEN, num_visual_tokens=N_VIS)
    print(f"  Fusion layer:\n  {fusion}\n")

    # ── Test 1: Vision + Text fusion ─────────────────────────────
    print("  Test 1: Vision + Text fusion")
    text_embeds = torch.randn(BATCH, SEQ_LEN, HIDDEN)
    visual_tokens = torch.randn(BATCH, N_VIS, HIDDEN)
    attention_mask = torch.ones(BATCH, SEQ_LEN, dtype=torch.long)

    fused_embeds, fused_mask = fusion(text_embeds, visual_tokens, attention_mask)

    expected_seq = N_VIS + SEQ_LEN  # 256 + 128 = 384
    assert fused_embeds.shape == (BATCH, expected_seq, HIDDEN), \
        f"Expected ({BATCH}, {expected_seq}, {HIDDEN}), got {fused_embeds.shape}"
    assert fused_mask is not None and fused_mask.shape == (BATCH, expected_seq), \
        f"Expected mask ({BATCH}, {expected_seq}), got {fused_mask.shape}"
    print(f"    fused_embeds: {fused_embeds.shape} βœ“")
    print(f"    fused_mask:   {fused_mask.shape} βœ“")

    # ── Test 2: Text-only (no vision) ────────────────────────────
    print("\n  Test 2: Text-only (no vision)")
    text_only, mask_only = fusion(text_embeds, None, attention_mask)
    assert text_only.shape == (BATCH, SEQ_LEN, HIDDEN)
    assert mask_only is not None and mask_only.shape == (BATCH, SEQ_LEN)
    print(f"    text_only:  {text_only.shape} βœ“")
    print(f"    mask_only:  {mask_only.shape} βœ“")

    # ── Test 3: No attention mask ────────────────────────────────
    print("\n  Test 3: Vision fusion without attention mask")
    fused_no_mask, none_mask = fusion(text_embeds, visual_tokens, None)
    assert fused_no_mask.shape == (BATCH, expected_seq, HIDDEN)
    assert none_mask is None
    print(f"    fused_embeds: {fused_no_mask.shape} βœ“")
    print(f"    fused_mask:   None βœ“")

    # ── Test 4: Single-image broadcast ───────────────────────────
    print("\n  Test 4: Single-image broadcast to batch")
    single_visual = torch.randn(1, N_VIS, HIDDEN)
    fused_bc, mask_bc = fusion(text_embeds, single_visual, attention_mask)
    assert fused_bc.shape == (BATCH, expected_seq, HIDDEN)
    print(f"    fused_embeds: {fused_bc.shape} βœ“ (broadcast 1 β†’ {BATCH})")

    # ── Test 5: Trainable params ─────────────────────────────────
    print("\n  Test 5: Parameter counts")
    info = fusion.get_trainable_params()
    # gate_proj: 4096*4096 + 4096 = 16,781,312
    # layer_norm: 4096 + 4096 = 8,192
    # text_gate: 1
    expected_params = HIDDEN * HIDDEN + 3 * HIDDEN + 1  # Linear(w+b) + LN(w+b) + gate
    assert info["trainable"] == expected_params, \
        f"Expected {expected_params}, got {info['trainable']}"
    print(f"    Trainable: {info['trainable']:,}")
    print(f"    Total:     {info['total']:,}")
    print(f"    Pct:       {info['trainable_pct']}%")

    # ── Test 6: Gradient flow ────────────────────────────────────
    print("\n  Test 6: Gradient flow through fusion")
    fusion.zero_grad()
    fused_embeds, _ = fusion(text_embeds, visual_tokens, attention_mask)
    loss = fused_embeds.sum()
    loss.backward()
    assert fusion.gate_proj.weight.grad is not None, "No gradient on gate_proj!"
    assert fusion.layer_norm.weight.grad is not None, "No gradient on layer_norm!"
    print("    gate_proj gradient:  βœ“")
    print("    layer_norm gradient: βœ“")

    print("\n  βœ“ All fusion layer tests passed!")
    print("=" * 60)