Upload 2 files
Browse files- MonoidForCausalLM.py +27 -20
MonoidForCausalLM.py
CHANGED
|
@@ -454,31 +454,38 @@ class MonoidAttention(nn.Module):
|
|
| 454 |
return self.o_proj(o), new_state
|
| 455 |
|
| 456 |
# ──────────────────────────────────────────────────────────
|
| 457 |
-
# Inference prefill (use_cache=True, T>1):
|
| 458 |
-
# 推理预填充 (use_cache=True, T>1):
|
| 459 |
# ──────────────────────────────────────────────────────────
|
| 460 |
-
#
|
| 461 |
-
#
|
| 462 |
-
#
|
| 463 |
-
#
|
|
|
|
|
|
|
| 464 |
if use_cache:
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
|
|
|
|
|
|
|
|
|
| 478 |
if monoid_cache is not None:
|
| 479 |
monoid_cache.update(self.layer_idx, final_state)
|
| 480 |
|
| 481 |
-
|
|
|
|
|
|
|
| 482 |
o = o.transpose(1, 2).contiguous().view(B, T, -1)
|
| 483 |
return self.o_proj(o), final_state
|
| 484 |
|
|
|
|
| 454 |
return self.o_proj(o), new_state
|
| 455 |
|
| 456 |
# ──────────────────────────────────────────────────────────
|
| 457 |
+
# Inference prefill (use_cache=True, T>1): parallel scan + readout
|
| 458 |
+
# 推理预填充 (use_cache=True, T>1): 并行扫描 + 读出
|
| 459 |
# ──────────────────────────────────────────────────────────
|
| 460 |
+
# Uses the same parallel_scan_with_state as training to leverage
|
| 461 |
+
# Triton CUDA kernel acceleration instead of O(T) Python loop.
|
| 462 |
+
# Memory: O(B·H·T·d²) — same as training path.
|
| 463 |
+
# 使用与训练相同的 parallel_scan_with_state 来利用
|
| 464 |
+
# Triton CUDA 核函数加速, 而非 O(T) 的 Python 循环。
|
| 465 |
+
# 内存: O(B·H·T·d²) — 与训练路径相同。
|
| 466 |
if use_cache:
|
| 467 |
+
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 468 |
+
states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
|
| 469 |
+
|
| 470 |
+
# Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
|
| 471 |
+
# 加回 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0
|
| 472 |
+
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,1]
|
| 473 |
+
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,1,1]
|
| 474 |
+
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 475 |
+
|
| 476 |
+
# Final state includes h0 contribution
|
| 477 |
+
# 最终状态包含 h0 贡献
|
| 478 |
+
total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,1,1]
|
| 479 |
+
S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d] (squeeze batch dim of h0)
|
| 480 |
+
# h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
|
| 481 |
+
final_state = (log_acc, S_final)
|
| 482 |
+
|
| 483 |
if monoid_cache is not None:
|
| 484 |
monoid_cache.update(self.layer_idx, final_state)
|
| 485 |
|
| 486 |
+
# Vectorized readout: o_t = q_t ยท S_t for all t
|
| 487 |
+
# 向量化读出: 一次性计算所有 t 的 o_t = q_t · S_t
|
| 488 |
+
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
|
| 489 |
o = o.transpose(1, 2).contiguous().view(B, T, -1)
|
| 490 |
return self.o_proj(o), final_state
|
| 491 |
|