msj19 commited on
Commit
f5d4dfb
·
verified ·
1 Parent(s): 038c065

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla3/ops/common/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  2. fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  3. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-310.pyc +0 -0
  4. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc +0 -0
  5. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc +0 -0
  6. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc +0 -0
  7. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc +0 -0
  8. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
  9. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc +0 -0
  10. fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  11. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc +0 -0
  12. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc +0 -0
  13. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  14. fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +164 -0
  15. fla3/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-312.pyc +0 -0
  16. fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  17. fla3/ops/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla3/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  19. fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  20. fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc +0 -0
  21. fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  22. fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc +0 -0
  23. fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  24. fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc +0 -0
  25. fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  26. fla3/ops/lightning_attn/__pycache__/__init__.cpython-310.pyc +0 -0
  27. fla3/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  28. fla3/ops/lightning_attn/__pycache__/chunk.cpython-310.pyc +0 -0
  29. fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  30. fla3/ops/linear_attn/__pycache__/chunk.cpython-310.pyc +0 -0
  31. fla3/ops/nsa/__pycache__/__init__.cpython-310.pyc +0 -0
  32. fla3/ops/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  33. fla3/ops/nsa/__pycache__/compression.cpython-310.pyc +0 -0
  34. fla3/ops/nsa/__pycache__/naive.cpython-312.pyc +0 -0
  35. fla3/ops/nsa/__pycache__/parallel.cpython-310.pyc +0 -0
  36. fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  37. fla3/ops/path_attn/__pycache__/cumprod_householder_fwd.cpython-310.pyc +0 -0
  38. fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc +0 -0
  39. fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc +0 -0
  40. fla3/ops/path_attn/__pycache__/parallel.cpython-310.pyc +0 -0
  41. fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc +0 -0
  42. fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc +0 -0
  43. fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dqh.cpython-310.pyc +0 -0
  44. fla3/ops/path_attn/__pycache__/parallel_path_fwd.cpython-310.pyc +0 -0
  45. fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
  46. fla3/ops/rebased/__pycache__/__init__.cpython-310.pyc +0 -0
  47. fla3/ops/rebased/__pycache__/parallel.cpython-310.pyc +0 -0
  48. fla3/ops/retention/__pycache__/__init__.cpython-312.pyc +0 -0
  49. fla3/ops/retention/__pycache__/chunk.cpython-310.pyc +0 -0
  50. fla3/ops/retention/__pycache__/fused_chunk.cpython-310.pyc +0 -0
fla3/ops/common/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (8.55 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc ADDED
Binary file (25 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc ADDED
Binary file (7.6 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc ADDED
Binary file (7.69 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
12
+
13
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
14
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
15
+
16
+
17
+ @triton.heuristics({
18
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def prepare_wy_repr_bwd_kernel(
31
+ A_ab_inv,
32
+ A_ak,
33
+ ag,
34
+ v,
35
+ dw,
36
+ du,
37
+ dv,
38
+ dv0,
39
+ dag,
40
+ dAak,
41
+ dAab,
42
+ cu_seqlens,
43
+ chunk_indices,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ IS_VARLEN: tl.constexpr,
52
+ ):
53
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+ if IS_VARLEN:
56
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
63
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
64
+ p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
65
+ p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
66
+
67
+ b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
68
+ b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
69
+ b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
70
+ b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
71
+ b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
72
+ b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)
73
+
74
+ for i_v in range(tl.cdiv(V, BV)):
75
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
76
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
77
+ p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
78
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
79
+ b_v = tl.load(p_v, boundary_check=(0, 1))
80
+ b_du = tl.load(p_du, boundary_check=(0, 1))
81
+ b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
82
+ b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
83
+ b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
84
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
85
+
86
+ m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :]
87
+ b_dA_tmp = tl.where(m_i, b_dA_tmp, 0)
88
+ b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
89
+ b_dA_ak = tl.where(m_i, b_dA_ak, 0)
90
+ tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
91
+ b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)
92
+
93
+ for i_k in range(tl.cdiv(K, BK)):
94
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
95
+ p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
96
+ p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
97
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
98
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
99
+ b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
100
+ b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
101
+ tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))
102
+
103
+ # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
104
+ # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
105
+ # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
106
+ # denote A = I - lower(A_ab), B = A^-1
107
+ # in the backward pass.
108
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T
109
+ # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
110
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
111
+ b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
112
+ b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
113
+ b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0)
114
+ tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
115
+
116
+
117
+ def chunk_dplr_bwd_wy(
118
+ A_ab_inv: torch.Tensor,
119
+ A_ak: torch.Tensor,
120
+ v: torch.Tensor,
121
+ ag: torch.Tensor,
122
+ dw: torch.Tensor,
123
+ du: torch.Tensor,
124
+ dv0: torch.Tensor,
125
+ cu_seqlens: Optional[torch.LongTensor],
126
+ chunk_size: int,
127
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128
+ A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
129
+ B, T, H, K, V = *dw.shape, du.shape[-1]
130
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
131
+
132
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
133
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
134
+ BK = min(triton.next_power_of_2(K), 64)
135
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
136
+
137
+ dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
138
+ dA_ak = torch.empty_like(A_ak, dtype=torch.float)
139
+ dv = torch.empty_like(v)
140
+ dag = torch.empty_like(ag)
141
+
142
+ prepare_wy_repr_bwd_kernel[(NT, B * H)](
143
+ A_ab_inv=A_ab_inv,
144
+ A_ak=A_ak,
145
+ ag=ag,
146
+ v=v,
147
+ dw=dw,
148
+ du=du,
149
+ dv=dv,
150
+ dv0=dv0,
151
+ dag=dag,
152
+ dAak=dA_ak,
153
+ dAab=dA_ab,
154
+ cu_seqlens=cu_seqlens,
155
+ chunk_indices=chunk_indices,
156
+ T=T,
157
+ H=H,
158
+ K=K,
159
+ V=V,
160
+ BT=BT,
161
+ BK=BK,
162
+ BV=BV,
163
+ )
164
+ return dA_ab, dA_ak, dv, dag
fla3/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (24.6 kB). View file
 
fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (25.6 kB). View file
 
fla3/ops/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (371 Bytes). View file
 
fla3/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (66.3 kB). View file
 
fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (305 Bytes). View file
 
fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (320 Bytes). View file
 
fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (27.3 kB). View file
 
fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (23.1 kB). View file
 
fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (7.07 kB). View file
 
fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla3/ops/lightning_attn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (338 Bytes). View file
 
fla3/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (353 Bytes). View file
 
fla3/ops/lightning_attn/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (2.88 kB). View file
 
fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (2.92 kB). View file
 
fla3/ops/linear_attn/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (2.88 kB). View file
 
fla3/ops/nsa/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (291 Bytes). View file
 
fla3/ops/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (306 Bytes). View file
 
fla3/ops/nsa/__pycache__/compression.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
fla3/ops/nsa/__pycache__/naive.cpython-312.pyc ADDED
Binary file (5.78 kB). View file
 
fla3/ops/nsa/__pycache__/parallel.cpython-310.pyc ADDED
Binary file (20.4 kB). View file
 
fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (273 Bytes). View file
 
fla3/ops/path_attn/__pycache__/cumprod_householder_fwd.cpython-310.pyc ADDED
Binary file (4.52 kB). View file
 
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc ADDED
Binary file (5.13 kB). View file
 
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc ADDED
Binary file (4.65 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel.cpython-310.pyc ADDED
Binary file (6.83 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (9.9 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc ADDED
Binary file (5.47 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dqh.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel_path_fwd.cpython-310.pyc ADDED
Binary file (4.95 kB). View file
 
fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
fla3/ops/rebased/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (239 Bytes). View file
 
fla3/ops/rebased/__pycache__/parallel.cpython-310.pyc ADDED
Binary file (9.32 kB). View file
 
fla3/ops/retention/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (452 Bytes). View file
 
fla3/ops/retention/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (3.35 kB). View file
 
fla3/ops/retention/__pycache__/fused_chunk.cpython-310.pyc ADDED
Binary file (9.26 kB). View file