OzTianlu committed on
Commit
ff3696b
·
verified ·
1 Parent(s): 1924b81

Upload 2 files

Browse files
Files changed (1) hide show
  1. MonoidForCausalLM.py +27 -20
MonoidForCausalLM.py CHANGED
@@ -454,31 +454,38 @@ class MonoidAttention(nn.Module):
454
  return self.o_proj(o), new_state
455
 
456
  # ══════════════════════════════════════════════════════════
457
- # Inference prefill (use_cache=True, T>1): fused scan + readout
458
- # 推理预填充 (use_cache=True, T>1): 融合扫描 + 读出
459
  # ══════════════════════════════════════════════════════════
460
- # Avoids materializing full [B,H,T,d,d] states tensor.
461
- # Peak memory: O(Hยทdยฒ) instead of O(TยทHยทdยฒ).
462
- # 避免实体化完整的 [B,H,T,d,d] 状态张量。
463
- # 峰值内存: O(H·d²) 而非 O(T·H·d²)。
 
 
464
  if use_cache:
465
- S = self.h0.expand(B, -1, -1, -1).clone() # [B,H,d,d]
466
- log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype)
467
- o_parts = []
468
- for t in range(T):
469
- kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
470
- decay = torch.exp(log_alpha[:, :, t]) # [B,H,1]
471
- while decay.dim() < S.dim():
472
- decay = decay.unsqueeze(-1)
473
- S = S * decay + kv_t
474
- o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
475
- log_acc = log_acc + log_alpha[:, :, t]
476
-
477
- final_state = (log_acc, S)
 
 
 
478
  if monoid_cache is not None:
479
  monoid_cache.update(self.layer_idx, final_state)
480
 
481
- o = torch.stack(o_parts, dim=2) # [B,H,T,d]
 
 
482
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
483
  return self.o_proj(o), final_state
484
 
 
454
  return self.o_proj(o), new_state
455
 
456
  # ══════════════════════════════════════════════════════════
457
+ # Inference prefill (use_cache=True, T>1): parallel scan + readout
458
+ # 推理预填充 (use_cache=True, T>1): 并行扫描 + 读出
459
  # ══════════════════════════════════════════════════════════
460
+ # Uses the same parallel_scan_with_state as training to leverage
461
+ # Triton CUDA kernel acceleration instead of O(T) Python loop.
462
+ # Memory: O(B·H·T·d²) — same as training path.
463
+ # 使用与训练相同的 parallel_scan_with_state 来利用
464
+ # Triton CUDA 核函数加速, 而非 O(T) 的 Python 循环。
465
+ # 内存: O(B·H·T·d²) — 与训练路径相同。
466
  if use_cache:
467
+ kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
468
+ states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
469
+
470
+ # Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
471
+ # 叠加 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0
472
+ cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,1]
473
+ h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,1,1]
474
+ states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
475
+
476
+ # Final state includes h0 contribution
477
+ # 最终状态包含 h0 贡献
478
+ total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,1,1]
479
+ S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d] (squeeze batch dim of h0)
480
+ # h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
481
+ final_state = (log_acc, S_final)
482
+
483
  if monoid_cache is not None:
484
  monoid_cache.update(self.layer_idx, final_state)
485
 
486
+ # Vectorized readout: o_t = q_t ยท S_t for all t
487
+ # ๅ‘้‡ๅŒ–่ฏปๅ‡บ: ไธ€ๆฌกๆ€ง่ฎก็ฎ—ๆ‰€ๆœ‰ t ็š„ o_t = q_t ยท S_t
488
+ o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
489
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
490
  return self.o_proj(o), final_state
491