OzTianlu committed
Commit 1924b81 · verified · 1 Parent(s): 64510f9

Upload 2 files

Files changed (1):
  1. MonoidForCausalLM.py +30 -19
MonoidForCausalLM.py CHANGED
@@ -483,28 +483,39 @@ class MonoidAttention(nn.Module):
         return self.o_proj(o), final_state
 
         # ══════════════════════════════════════════════════════════
-        # Training path: memory-efficient sequential scan + inline readout
         # ══════════════════════════════════════════════════════════
-        # Loop token-by-token with running state S=[B,H,d,d].
-        # Peak memory: O(B·H·d²) instead of O(B·H·T·d²).
-        # Autograd records each step for correct gradient computation.
         #
-        S = self.h0.expand(B, -1, -1, -1).clone()  # [B,H,d,d]
-        o_parts = []
-        for t in range(T):
-            kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
-            decay = torch.exp(log_alpha[:, :, t])  # [B,H,1]
-            while decay.dim() < S.dim():
-                decay = decay.unsqueeze(-1)
-            S = S * decay + kv_t
-            o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
-
-        o = torch.stack(o_parts, dim=2)  # [B,H,T,d]
         o = o.transpose(1, 2).contiguous().view(B, T, -1)
         return self.o_proj(o), None
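The removed path is a plain left-to-right linear recurrence: S_t = α_t·S_{t-1} + k_t⊗v_t with readout o_t = q_t·S_t. A minimal standalone sketch of that recurrence, collapsed to a single (batch, head) slice in NumPy with illustrative names (this is not the repo's API):

```python
import numpy as np

def sequential_scan(q, k, v, log_alpha, h0):
    """Reference recurrence for one (batch, head) slice.

    q, k, v: [T, d]; log_alpha: [T]; h0: [d, d].
    Step: S_t = exp(log_alpha_t) * S_{t-1} + outer(k_t, v_t); o_t = q_t @ S_t.
    Only the running [d, d] state is held, so peak memory stays O(d^2)
    regardless of sequence length T.
    """
    S = h0.copy()
    outs = []
    for t in range(q.shape[0]):
        S = np.exp(log_alpha[t]) * S + np.outer(k[t], v[t])
        outs.append(q[t] @ S)
    return np.stack(outs)  # [T, d]
```

With log_alpha = 0 (no decay) and h0 = 0 this reduces to a running sum of outer products, i.e. unnormalized linear attention; the trade-off the commit removes is that each of the T iterations launches several small kernels from Python.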
 
         return self.o_proj(o), final_state
 
         # ══════════════════════════════════════════════════════════
+        # Training path: parallel scan + vectorized readout
         # ══════════════════════════════════════════════════════════
+        # Materialize the full kv tensor [B,H,T,d,d] and scan in one pass.
+        # Memory: O(B·H·T·d²) — trades memory for speed.
+        # Eliminates the T×30 Python-loop kernel launches for the outer
+        # product and readout; the scan itself is parallel when a CUDA
+        # kernel is available.
         #
+
+        # Vectorized outer product: kv_t = k_t ⊗ v_t for all t at once
+        kv = torch.einsum('bhtd, bhte -> bhtde', k, v)  # [B,H,T,d,d]
+
+        # Parallel prefix scan: S_t = α_t·S_{t-1} + kv_t (starting from S = 0)
+        # Keep log_alpha as [B,H,T,1] — the CUDA kernel backward expects this shape.
+        states = parallel_scan(log_alpha, kv)  # [B,H,T,d,d]
+
+        # Add the h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
+        cum_log_alpha = torch.cumsum(log_alpha, dim=2)  # [B,H,T,1]
+        h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1)  # [B,H,T,1,1]
+        states = states + h0_decay * self.h0.unsqueeze(2)  # broadcast h0 [1,H,1,d,d]
+
+        # Vectorized readout: o_t = q_t · S_t for all t at once
+        o = torch.einsum('bhtd, bhtde -> bhte', q, states)  # [B,H,T,d]
         o = o.transpose(1, 2).contiguous().view(B, T, -1)
         return self.o_proj(o), None
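The added path's correctness rests on one identity: running the scan from S = 0 and then adding (∏_{i=0}^{t} α_i)·h0 gives the same states as running the recurrence from S_0 = h0, because unrolling S_t = α_t·S_{t-1} + kv_t puts exactly that cumulative decay in front of the initial state. A hedged NumPy check of the algebra, collapsed to one (batch, head) slice; the repo's `parallel_scan` is stood in for by a plain loop, since it is the zero-init-plus-correction rewrite being verified, not the scan's parallelism:

```python
import numpy as np

def vectorized_path(q, k, v, log_alpha, h0):
    """Mirror of the new training path for one (batch, head) slice.

    q, k, v: [T, d]; log_alpha: [T]; h0: [d, d].
    """
    T, d = q.shape
    kv = np.einsum('td,te->tde', k, v)               # [T, d, d] outer products
    states = np.empty_like(kv)
    S = np.zeros((d, d))
    for t in range(T):                               # stand-in for parallel_scan
        S = np.exp(log_alpha[t]) * S + kv[t]
        states[t] = S
    h0_decay = np.exp(np.cumsum(log_alpha))          # prod_{i<=t} alpha_i, [T]
    states = states + h0_decay[:, None, None] * h0   # add decayed h0 term
    return np.einsum('td,tde->te', q, states)        # [T, d] readout

def sequential_path(q, k, v, log_alpha, h0):
    """Token-by-token recurrence started directly from S_0 = h0 (old path)."""
    S = h0.copy()
    outs = []
    for t in range(len(q)):
        S = np.exp(log_alpha[t]) * S + np.outer(k[t], v[t])
        outs.append(q[t] @ S)
    return np.stack(outs)
```

Feeding both functions the same random q, k, v, decays, and h0 yields equal outputs to floating-point tolerance, which is what licenses splitting the scan from the h0 term in the committed code.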