OzTianlu commited on
Commit
64510f9
·
verified ·
1 Parent(s): 471fab3

Upload 2 files

Browse files
Files changed (1) hide show
  1. MonoidForCausalLM.py +21 -34
MonoidForCausalLM.py CHANGED
@@ -483,43 +483,30 @@ class MonoidAttention(nn.Module):
483
  return self.o_proj(o), final_state
484
 
485
  # ═════════════════════════════════════════════════════════
486
- # Training path (parallel scan): O(T) via prefix sum
487
-        # 训练路径 (并行扫描): 通过前缀和 O(T)
488
  # ═════════════════════════════════════════════════════════
489
- # For a full sequence of length T, compute ALL prefix states
490
- # S_1, S_2, ..., S_T simultaneously using parallel prefix scan.
491
-        # Complexity: O(T) work, O(log T) depth — GPU-friendly.
492
  #
493
-        # 对长度为 T 的完整序列, 使用并行前缀扫描同时计算所有前缀状态
494
-        # S_1, S_2, ..., S_T。
495
-        # 复杂度: O(T) 工作量, O(log T) 深度 — GPU 友好。
496
-
497
-        # Batch outer product: kv_{t} = k_t ⊗ v_t for all t
498
-        # 批量外积: kv_{t} = k_t ⊗ v_t, 对所有 t
499
- kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
500
- states = parallel_scan(log_alpha, kv)
501
- del kv # free [B,H,T,d,d] early
502
- final_state = None
503
-
504
-        # ── Incorporate h0: make training consistent with inference ──
505
-        # ── 融入 h0: 使训练与推理一致 ──
506
- # parallel_scan starts from S_0 = 0, but inference starts from S_0 = h0.
507
-        # Fix: S_t(with h0) = h0 · Π_{i=1}^{t} α_i + S_t(from scan)
508
-        # The cumulative decay Π_{i=1}^{t} α_i = exp(Σ_{i=1}^{t} log_α_i).
509
-        # parallel_scan 从 S_0 = 0 开始, 但推理从 S_0 = h0 开始。
510
-        # 修正: S_t(含h0) = h0 · Π_{i=1}^{t} α_i + S_t(扫描结果)
511
-        # 累积衰减 Π_{i=1}^{t} α_i = exp(Σ_{i=1}^{t} log_α_i)。
512
- cum_log_decay = torch.cumsum(log_alpha.squeeze(-1), dim=2) # [B,H,T]
513
- cum_decay = torch.exp(cum_log_decay).unsqueeze(-1).unsqueeze(-1) # [B,H,T,1,1]
514
- states = states + self.h0.unsqueeze(2) * cum_decay # [B,H,T,d,d]
515
- del cum_decay
516
-
517
-        # Readout: o_t = q_t · S_t for all t simultaneously
518
-        # 读出: o_t = q_t · S_t, 对所有 t 同时计算
519
- o = torch.einsum('bhtd, bhtde -> bhte', q, states)
520
- del states # free [B,H,T,d,d]
521
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
522
- return self.o_proj(o), final_state
523
 
524
 
525
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
483
  return self.o_proj(o), final_state
484
 
485
  # ═════════════════════════════════════════════════════════
486
+ # Training path: memory-efficient sequential scan + inline readout
487
+        # 训练路径: 内存高效的串行扫描 + 内联读出
488
  # ═════════════════════════════════════════════════════════
489
+ # Loop token-by-token with running state S=[B,H,d,d].
490
+        # Peak memory: O(B·H·d²) instead of O(B·H·T·d²).
491
+ # Autograd records each step for correct gradient computation.
492
  #
493
+        # 逐 token 循环, 使用运行状态 S=[B,H,d,d]。
494
+        # 峰值内存: O(B·H·d²) 而非 O(B·H·T·d²)。
495
+        # Autograd 记录每步操作以正确计算梯度。
496
+
497
+ S = self.h0.expand(B, -1, -1, -1).clone() # [B,H,d,d]
498
+ o_parts = []
499
+ for t in range(T):
500
+ kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
501
+ decay = torch.exp(log_alpha[:, :, t]) # [B,H,1]
502
+ while decay.dim() < S.dim():
503
+ decay = decay.unsqueeze(-1)
504
+ S = S * decay + kv_t
505
+ o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
506
+
507
+ o = torch.stack(o_parts, dim=2) # [B,H,T,d]
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  o = o.transpose(1, 2).contiguous().view(B, T, -1)
509
+ return self.o_proj(o), None
510
 
511
 
512
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━