Faaz committed on
Commit ·
4e9835e
1
Parent(s): cdc806e
Fix Phase 2: route text-only inputs through the fusion layer via a learnable residual gate so gradients flow
Browse files- src/model/fusion_layer.py +12 -2
src/model/fusion_layer.py
CHANGED
|
@@ -46,6 +46,10 @@ class VisionLanguageFusion(nn.Module):
|
|
| 46 |
self.gate_proj = nn.Linear(hidden_size, hidden_size)
|
| 47 |
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def forward(
|
| 50 |
self,
|
| 51 |
text_embeds: torch.Tensor,
|
|
@@ -65,9 +69,15 @@ class VisionLanguageFusion(nn.Module):
|
|
| 65 |
fused_embeds: (batch, 256 + seq_len, hidden_size) if visual, else unchanged.
|
| 66 |
fused_mask: Extended attention mask, or None if input mask was None.
|
| 67 |
"""
|
| 68 |
-
# Text-only path —
|
|
|
|
|
|
|
|
|
|
| 69 |
if visual_tokens is None:
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
batch_size = text_embeds.shape[0]
|
| 73 |
v_batch = visual_tokens.shape[0]
|
|
|
|
| 46 |
self.gate_proj = nn.Linear(hidden_size, hidden_size)
|
| 47 |
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 48 |
|
| 49 |
+
# Text-only residual gate (learnable scalar). NOTE(review): with a sigmoid
|
| 50 |
+
# gate, text_gate=0 gives alpha=0.5 at init — a 50/50 blend, not identity.
|
| 51 |
+
self.text_gate = nn.Parameter(torch.zeros(1))
|
| 52 |
+
|
| 53 |
def forward(
|
| 54 |
self,
|
| 55 |
text_embeds: torch.Tensor,
|
|
|
|
| 69 |
fused_embeds: (batch, 256 + seq_len, hidden_size) if visual, else unchanged.
|
| 70 |
fused_mask: Extended attention mask, or None if input mask was None.
|
| 71 |
"""
|
| 72 |
+
# Text-only path — apply a learnable residual gate through the
|
| 73 |
+
# fusion parameters so gradients can flow to fusion even without images.
|
| 74 |
+
# NOTE(review): at init text_gate=0 → sigmoid(0)=0.5, so the output is an
|
| 75 |
+
# equal blend of text_embeds and the transform — not ≈ text_embeds as stated.
|
| 76 |
if visual_tokens is None:
|
| 77 |
+
alpha = torch.sigmoid(self.text_gate)
|
| 78 |
+
transformed = self.layer_norm(self.gate_proj(text_embeds))
|
| 79 |
+
fused_embeds = text_embeds + alpha * (transformed - text_embeds)
|
| 80 |
+
return fused_embeds, attention_mask
|
| 81 |
|
| 82 |
batch_size = text_embeds.shape[0]
|
| 83 |
v_batch = visual_tokens.shape[0]
|