Update scripts/model_v4.py
Browse files- scripts/model_v4.py +24 -21
scripts/model_v4.py
CHANGED
|
@@ -724,17 +724,20 @@ class Attention(nn.Module):
|
|
| 724 |
q = q * mod.unsqueeze(-1) # [B, heads, N, head_dim]
|
| 725 |
k = k * mod.unsqueeze(-1)
|
| 726 |
|
| 727 |
-
# Compute attention
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
# === Sol Temperature Scaling ===
|
| 731 |
if sol_temperature is not None:
|
| 732 |
-
# temperature
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
out = out.transpose(1, 2).reshape(B, N, -1)
|
| 739 |
|
| 740 |
return self.out_proj(out)
|
|
@@ -817,19 +820,19 @@ class JointAttention(nn.Module):
|
|
| 817 |
k = torch.cat([txt_k, img_k], dim=2)
|
| 818 |
v = torch.cat([txt_v, img_v], dim=2)
|
| 819 |
|
| 820 |
-
# Text attention (
|
| 821 |
-
|
| 822 |
-
txt_attn = F.softmax(txt_scores, dim=-1)
|
| 823 |
-
txt_out = torch.matmul(txt_attn, v)
|
| 824 |
txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
|
| 825 |
|
| 826 |
-
# Image attention (Sol temperature
|
| 827 |
-
img_scores = torch.matmul(img_q, k.transpose(-2, -1)) * self.scale
|
| 828 |
if sol_temperature is not None:
|
| 829 |
-
temp = sol_temperature.
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
|
|
|
|
|
|
|
|
|
| 833 |
img_out = img_out.transpose(1, 2).reshape(B, N, -1)
|
| 834 |
|
| 835 |
return self.txt_out(txt_out), self.img_out(img_out)
|
|
@@ -1562,4 +1565,4 @@ def test_model():
|
|
| 1562 |
|
| 1563 |
|
| 1564 |
#if __name__ == "__main__":
|
| 1565 |
-
#
|
|
|
|
| 724 |
q = q * mod.unsqueeze(-1) # [B, heads, N, head_dim]
|
| 725 |
k = k * mod.unsqueeze(-1)
|
| 726 |
|
| 727 |
+
# === Compute attention with SDPA (Flash Attention) ===
|
| 728 |
+
# Sol temperature is applied via scale modification
|
|
|
|
|
|
|
| 729 |
if sol_temperature is not None:
|
| 730 |
+
# Average temperature across heads for SDPA scale
|
| 731 |
+
# temperature: [B, num_heads] → one value per sample, [B, 1] (SDPA's `scale` argument accepts only a single float)
|
| 732 |
+
temp = sol_temperature.mean(dim=1, keepdim=True).clamp(min=0.1) # [B, 1]
|
| 733 |
+
effective_scale = self.scale / temp.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1, 1]
|
| 734 |
+
# Pre-scale both Q and K by sqrt(effective_scale) instead of post-scaling the scores (mathematically equivalent: the two factors multiply to effective_scale in Q·Kᵀ)
|
| 735 |
+
q = q * (effective_scale.sqrt())
|
| 736 |
+
k = k * (effective_scale.sqrt())
|
| 737 |
+
out = F.scaled_dot_product_attention(q, k, v, scale=1.0)
|
| 738 |
+
else:
|
| 739 |
+
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
| 740 |
+
|
| 741 |
out = out.transpose(1, 2).reshape(B, N, -1)
|
| 742 |
|
| 743 |
return self.out_proj(out)
|
|
|
|
| 820 |
k = torch.cat([txt_k, img_k], dim=2)
|
| 821 |
v = torch.cat([txt_v, img_v], dim=2)
|
| 822 |
|
| 823 |
+
# Text attention with SDPA (no Sol modulation)
|
| 824 |
+
txt_out = F.scaled_dot_product_attention(txt_q, k, v, scale=self.scale)
|
|
|
|
|
|
|
| 825 |
txt_out = txt_out.transpose(1, 2).reshape(B, L, -1)
|
| 826 |
|
| 827 |
+
# Image attention with SDPA (Sol temperature via scale modification)
|
|
|
|
| 828 |
if sol_temperature is not None:
|
| 829 |
+
temp = sol_temperature.mean(dim=1, keepdim=True).clamp(min=0.1)
|
| 830 |
+
effective_scale = self.scale / temp.unsqueeze(-1).unsqueeze(-1)
|
| 831 |
+
img_q_scaled = img_q * (effective_scale.sqrt())
|
| 832 |
+
k_scaled = k * (effective_scale.sqrt())
|
| 833 |
+
img_out = F.scaled_dot_product_attention(img_q_scaled, k_scaled, v, scale=1.0)
|
| 834 |
+
else:
|
| 835 |
+
img_out = F.scaled_dot_product_attention(img_q, k, v, scale=self.scale)
|
| 836 |
img_out = img_out.transpose(1, 2).reshape(B, N, -1)
|
| 837 |
|
| 838 |
return self.txt_out(txt_out), self.img_out(img_out)
|
|
|
|
| 1565 |
|
| 1566 |
|
| 1567 |
#if __name__ == "__main__":
|
| 1568 |
+
# test_model()
|