Faaz committed on
Commit
4e9835e
·
1 Parent(s): cdc806e

Fix Phase 2: fusion layer processes text-only via learnable residual gate for gradient flow

Browse files
Files changed (1) hide show
  1. src/model/fusion_layer.py +12 -2
src/model/fusion_layer.py CHANGED
@@ -46,6 +46,10 @@ class VisionLanguageFusion(nn.Module):
46
  self.gate_proj = nn.Linear(hidden_size, hidden_size)
47
  self.layer_norm = nn.LayerNorm(hidden_size)
48
 
 
 
 
 
49
  def forward(
50
  self,
51
  text_embeds: torch.Tensor,
@@ -65,9 +69,15 @@ class VisionLanguageFusion(nn.Module):
65
  fused_embeds: (batch, 256 + seq_len, hidden_size) if visual, else unchanged.
66
  fused_mask: Extended attention mask, or None if input mask was None.
67
  """
68
- # Text-only path — no vision tokens to fuse
 
 
 
69
  if visual_tokens is None:
70
- return text_embeds, attention_mask
 
 
 
71
 
72
  batch_size = text_embeds.shape[0]
73
  v_batch = visual_tokens.shape[0]
 
46
  self.gate_proj = nn.Linear(hidden_size, hidden_size)
47
  self.layer_norm = nn.LayerNorm(hidden_size)
48
 
49
+ # Text-only residual gate (learnable scalar logit, initialised to 0;
50
+ # sigmoid(0)=0.5, so the text path starts as an even blend, not identity)
51
+ self.text_gate = nn.Parameter(torch.zeros(1))
52
+
53
  def forward(
54
  self,
55
  text_embeds: torch.Tensor,
 
69
  fused_embeds: (batch, 256 + seq_len, hidden_size) if visual, else unchanged.
70
  fused_mask: Extended attention mask, or None if input mask was None.
71
  """
72
+ # Text-only path — apply a learnable residual gate through the
73
+ # fusion parameters so gradients can flow to fusion even without images.
74
+ # At init text_gate=0 → sigmoid(0)=0.5, so the output starts as an equal
75
+ # blend of text_embeds and the transformed embeddings, not ≈ text_embeds.
76
  if visual_tokens is None:
77
+ alpha = torch.sigmoid(self.text_gate)
78
+ transformed = self.layer_norm(self.gate_proj(text_embeds))
79
+ fused_embeds = text_embeds + alpha * (transformed - text_embeds)
80
+ return fused_embeds, attention_mask
81
 
82
  batch_size = text_embeds.shape[0]
83
  v_batch = visual_tokens.shape[0]