AbstractPhil committed on
Commit
a5c7471
·
verified ·
1 Parent(s): 1f24ada

Update scripts/model_v4.py

Browse files
Files changed (1) hide show
  1. scripts/model_v4.py +24 -21
scripts/model_v4.py CHANGED
@@ -724,17 +724,20 @@ class Attention(nn.Module):
724
  q = q * mod.unsqueeze(-1) # [B, heads, N, head_dim]
725
  k = k * mod.unsqueeze(-1)
726
 
727
- # Compute attention scores
728
- scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [B, heads, N, N]
729
-
730
- # === Sol Temperature Scaling ===
731
  if sol_temperature is not None:
732
- # temperature: [B, num_heads] → [B, heads, 1, 1]
733
- temp = sol_temperature.unsqueeze(-1).unsqueeze(-1).clamp(min=0.1)
734
- scores = scores / temp
735
-
736
- attn = F.softmax(scores, dim=-1)
737
- out = torch.matmul(attn, v)
 
 
 
 
 
738
  out = out.transpose(1, 2).reshape(B, N, -1)
739
 
740
  return self.out_proj(out)
@@ -817,19 +820,19 @@ class JointAttention(nn.Module):
817
  k = torch.cat([txt_k, img_k], dim=2)
818
  v = torch.cat([txt_v, img_v], dim=2)
819
 
820
- # Text attention (NO Sol temperature - text is not spatial)
821
- txt_scores = torch.matmul(txt_q, k.transpose(-2, -1)) * self.scale
822
- txt_attn = F.softmax(txt_scores, dim=-1)
823
- txt_out = torch.matmul(txt_attn, v)
824
  txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
825
 
826
- # Image attention (Sol temperature applies here only)
827
- img_scores = torch.matmul(img_q, k.transpose(-2, -1)) * self.scale
828
  if sol_temperature is not None:
829
- temp = sol_temperature.unsqueeze(-1).unsqueeze(-1).clamp(min=0.1)
830
- img_scores = img_scores / temp
831
- img_attn = F.softmax(img_scores, dim=-1)
832
- img_out = torch.matmul(img_attn, v)
 
 
 
833
  img_out = img_out.transpose(1, 2).reshape(B, N, -1)
834
 
835
  return self.txt_out(txt_out), self.img_out(img_out)
@@ -1562,4 +1565,4 @@ def test_model():
1562
 
1563
 
1564
  #if __name__ == "__main__":
1565
- # test_model()
 
724
  q = q * mod.unsqueeze(-1) # [B, heads, N, head_dim]
725
  k = k * mod.unsqueeze(-1)
726
 
727
+ # === Compute attention with SDPA (Flash Attention) ===
728
+ # Sol temperature is applied via scale modification
 
 
729
  if sol_temperature is not None:
730
+ # Average temperature across heads for SDPA scale
731
+ # temperature: [B, num_heads] → scalar per sample (SDPA limitation)
732
+ temp = sol_temperature.mean(dim=1, keepdim=True).clamp(min=0.1) # [B, 1]
733
+ effective_scale = self.scale / temp.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1, 1]
734
+ # Pre-scale Q instead of post-scale scores (mathematically equivalent)
735
+ q = q * (effective_scale.sqrt())
736
+ k = k * (effective_scale.sqrt())
737
+ out = F.scaled_dot_product_attention(q, k, v, scale=1.0)
738
+ else:
739
+ out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
740
+
741
  out = out.transpose(1, 2).reshape(B, N, -1)
742
 
743
  return self.out_proj(out)
 
820
  k = torch.cat([txt_k, img_k], dim=2)
821
  v = torch.cat([txt_v, img_v], dim=2)
822
 
823
+ # Text attention with SDPA (no Sol modulation)
824
+ txt_out = F.scaled_dot_product_attention(txt_q, k, v, scale=self.scale)
 
 
825
  txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
826
 
827
+ # Image attention with SDPA (Sol temperature via scale modification)
 
828
  if sol_temperature is not None:
829
+ temp = sol_temperature.mean(dim=1, keepdim=True).clamp(min=0.1)
830
+ effective_scale = self.scale / temp.unsqueeze(-1).unsqueeze(-1)
831
+ img_q_scaled = img_q * (effective_scale.sqrt())
832
+ k_scaled = k * (effective_scale.sqrt())
833
+ img_out = F.scaled_dot_product_attention(img_q_scaled, k_scaled, v, scale=1.0)
834
+ else:
835
+ img_out = F.scaled_dot_product_attention(img_q, k, v, scale=self.scale)
836
  img_out = img_out.transpose(1, 2).reshape(B, N, -1)
837
 
838
  return self.txt_out(txt_out), self.img_out(img_out)
 
1565
 
1566
 
1567
  #if __name__ == "__main__":
1568
+ # test_model()