Lekr0 committed on
Commit
8ef1d10
·
verified ·
1 Parent(s): 7cd1cbc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SpecForge-ext/cache/compiled_kernels/2e/c2etayrlw6ivbtj3uahv4l3y7x534xpzfww6cyknbe2kfe54yei5.py +43 -0
  2. SpecForge-ext/cache/compiled_kernels/2h/c2hvdjlmxyob2txn4nddktnqpzxakuy4vukk46jxvlks5plszr5r.py +62 -0
  3. SpecForge-ext/cache/compiled_kernels/2s/c2sasa5yimiwlxmywmcvgtuh2fvol2mvhppzairkbqvuwicnbd5y.py +62 -0
  4. SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py +835 -0
  5. SpecForge-ext/cache/compiled_kernels/2x/3dd4effcc6c7612a42d28cac3a6342345062808f2904d114a779d751ce7956b2.best_config +1 -0
  6. SpecForge-ext/cache/compiled_kernels/2x/c2xgz3ru7j7sptpmoelww3e5lkmoeimpyawjjwmcpaujxtdorhwr.py +56 -0
  7. SpecForge-ext/cache/compiled_kernels/2x/c2xsu5ssb3jappbwwrbr53muiaoukfjzccks7reewucgvplouktq.py +43 -0
  8. SpecForge-ext/cache/compiled_kernels/2x/c2xunts4zntd65pabgkkxg5ylyh7sahfyogzmgljfiljdui4o365.py +62 -0
  9. SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py +56 -0
  10. SpecForge-ext/cache/compiled_kernels/3k/c0fc7bc81a7e9d406f980957c0881903e8484dd7f57d970f2ddd21ca3ab2994d.best_config +1 -0
  11. SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py +56 -0
  12. SpecForge-ext/cache/compiled_kernels/3m/c3mfnz3jqpdzlott45yvd2kki53nhik366siiuob2jitdkwx6tyg.py +46 -0
  13. SpecForge-ext/cache/compiled_kernels/3p/c3pmafpvrty43do4nz3cf2mvhkihfulfxbiolmcu2votxja4s56e.py +352 -0
  14. SpecForge-ext/cache/compiled_kernels/3u/c3ukv75kqyf3oeeogojmsgmsebbc2fg3rqs4dsmnshhsgj4hjkzx.py +168 -0
  15. SpecForge-ext/cache/compiled_kernels/3x/88057732cb1d7a775c254455fe42105016cd2d1ced3af1bd1fb079691b5972a1.best_config +1 -0
  16. SpecForge-ext/cache/compiled_kernels/3x/c3xxifdzdkxpgs4yujq2dd2lwhcylvzgstk2hao7g2yjirkxlafr.py +86 -0
  17. SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py +45 -0
  18. SpecForge-ext/cache/compiled_kernels/42/c42olsblh7ymaib2tr5gwhfzuighing5bkpmabq5hx7nxumtbsig.py +835 -0
  19. SpecForge-ext/cache/compiled_kernels/44/0525feab4902a63d7e5c68635e4f503cce07497c146c803e2f852ae21bd67e9c.best_config +1 -0
  20. SpecForge-ext/cache/compiled_kernels/44/c444sn4254wny52itl5mlassuxuptb3kc3p6r3r3stsx4lyt6t3r.py +527 -0
  21. SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py +56 -0
  22. SpecForge-ext/cache/compiled_kernels/47/c477sca7o5mbhj2pknepbw5b3rzush4uzefidcyxm6ysescvabgf.py +48 -0
  23. SpecForge-ext/cache/compiled_kernels/47/c47ib4wwcplmaudqv7246conlxcjjylxi5ahlxye6ebv6onrgoxg.py +37 -0
  24. SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py +46 -0
  25. SpecForge-ext/cache/compiled_kernels/4d/71d6529f8e555bde19385922da8b4def675c0510b5b664b1bf4faebe9f8928eb.best_config +1 -0
  26. SpecForge-ext/cache/compiled_kernels/4d/c4dlhtn4sdcx4l7dqawohb4u6hvu3xnzxvwazc5qrp6haqgnm3ev.py +24 -0
  27. SpecForge-ext/cache/compiled_kernels/4d/c4dw6ykxgbwk3glacutxkpzwhapvr5oszjet3n4i4q3snumjzm3x.py +835 -0
  28. SpecForge-ext/cache/compiled_kernels/57/c57svaeo74ael4oxqveudfvhx4xfmu3ikrmvljcnixb4kiqagrzn.py +47 -0
  29. SpecForge-ext/cache/compiled_kernels/5b/c5bepmksmwsj5jf67nbhozcficyzxyuodtllwlc5wms64ubwgqh6.py +835 -0
  30. SpecForge-ext/cache/compiled_kernels/5b/c5bjpwebwz42wx5vrxljhccvdqdyttlxd2hpbtzw7ia3oq6c33ne.py +62 -0
  31. SpecForge-ext/cache/compiled_kernels/5g/c5g26r4ygcctmxuptx453t3kikkqukh73touvd4yxv7futs36kgf.py +66 -0
  32. SpecForge-ext/cache/compiled_kernels/5g/c5g4egnryommgtc4braxeh3xxhfypsb6ec3v4sx3gxmfoeho5bzl.py +1083 -0
  33. SpecForge-ext/cache/compiled_kernels/5g/f7b122e9e44d29c2d695b9633c17bd6eee619900dd2325a29224a34ca8164da3.best_config +1 -0
  34. SpecForge-ext/cache/compiled_kernels/5j/c5jkwkcyjhbh7wkjfrwxf42i764iqjvwig2sk7me5dmfpiuqgldn.py +410 -0
  35. SpecForge-ext/cache/compiled_kernels/5j/c5jvrblz5ym34kn4ssfnzfxabvx53ffrcnqlmwnjv3gmkqfkgo4v.py +1083 -0
  36. SpecForge-ext/cache/compiled_kernels/5o/c5oswnr7dpwwcqp5m7thpaf4owvpseka6amdvdroewnorzsre6n2.py +56 -0
  37. SpecForge-ext/cache/compiled_kernels/5o/fd9062ce8e19c42a2ac9826803e021d98494b253c62ff7bfe753f34e0c863929.best_config +1 -0
  38. SpecForge-ext/cache/compiled_kernels/5r/67cf09099929c6923eb0884c24406ef33ddeccd796aae57f0484cb9e81164741.best_config +1 -0
  39. SpecForge-ext/cache/compiled_kernels/5r/a92927e5b439ed1b110bb08d838b1d456f558deda739d61f97499093c88a877a.best_config +1 -0
  40. SpecForge-ext/cache/compiled_kernels/5r/c5rruemruzlohybkl4bagtqtk5athtuxf3eoj37rjptdltrrri3r.py +56 -0
  41. SpecForge-ext/cache/compiled_kernels/5r/c5rs7ak7dbb5csdzszewryjtnxnhv7xpwdjvgqrsp5lfmmej4poi.py +24 -0
  42. SpecForge-ext/cache/compiled_kernels/5w/c5w735qbviioww7vfjj36tk57xo254oei3wqkunaiekkjd5pfcph.py +43 -0
  43. SpecForge-ext/cache/compiled_kernels/5z/c5z5oj5ee2bvvg2pkzwf6smszdy73565nillm7gopvokmvrvu2dp.py +711 -0
  44. SpecForge-ext/cache/compiled_kernels/6j/c6jzxztdxbjv5b23nfmgzgtizqp77h7aeak5j2jukmz3roqeiw3k.py +24 -0
  45. SpecForge-ext/cache/compiled_kernels/6j/e9590d30530b6f20cd8332cd18dfb56bc33c5ce0f73ebafd83fbd8da1a7ab8fe.best_config +1 -0
  46. SpecForge-ext/cache/compiled_kernels/6m/c6mwcfy2ykv3p5alrzh4sx4ajhl5davetqobw2pytyc2kalbo2wk.py +168 -0
  47. SpecForge-ext/cache/compiled_kernels/6o/c6o7jlqhfbi4ry3uni47hefilsmfptqopfdxwc3plgg65s2mqzse.py +322 -0
  48. SpecForge-ext/cache/compiled_kernels/6r/734ee9f72fcbbc036c304bd9fc428175dc6febf6da61f182679d20ad4d8b7f41.best_config +1 -0
  49. SpecForge-ext/cache/compiled_kernels/6r/c6r6adrqwwhzfcdd5cyhmwl3cptpvwwhedzdpranw7esxeg5oyia.py +56 -0
  50. SpecForge-ext/cache/compiled_kernels/6r/c6rbvgm53jr3nux66durqhisanccgaebzxcdjdhdrphqjpyu2t5r.py +62 -0
SpecForge-ext/cache/compiled_kernels/2e/c2etayrlw6ivbtj3uahv4l3y7x534xpzfww6cyknbe2kfe54yei5.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
@triton_heuristics.reduction(
    size_hints={'x': 128, 'r0_': 16},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    """Row-wise integer sum: widen to int64, reduce, narrow back to int32.

    Each program handles XBLOCK rows. For row ``x`` the kernel reads
    ``in_ptr0[r + ks0 * x]`` for ``r`` in ``[0, r0_numel)``, accumulates the
    values in an int64 tile, reduces along the column axis and stores the
    result cast to int32 through ``out_ptr1``.

    Args:
        in_ptr0: input pointer (declared ``*i32`` in the launch signature).
        out_ptr1: output pointer (declared ``*i32`` in the launch signature).
        ks0: element stride between consecutive input rows.
        ks1: divisor used to split the row index into the two output
            coordinates.
        xnumel: number of rows to reduce.
        r0_numel: number of elements summed per row.
        XBLOCK: rows handled per program (compile-time).
        R0_BLOCK: columns consumed per loop iteration (compile-time).
    """
    # Row coordinates owned by this program, shaped [XBLOCK, 1].
    row_idx = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = row_idx < xnumel
    # Column offsets inside one reduction tile, shaped [1, R0_BLOCK].
    col_base = tl.arange(0, R0_BLOCK)[None, :]

    acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for col_start in range(0, r0_numel, R0_BLOCK):
        col_idx = col_start + col_base
        live = (col_idx < r0_numel) & row_ok
        # Masked-off lanes read as zero and are never folded into `acc`.
        vals = tl.load(in_ptr0 + (col_idx + ks0*row_idx), live, eviction_policy='evict_first', other=0.0)
        widened = tl.broadcast_to(vals.to(tl.int64), [XBLOCK, R0_BLOCK])
        acc = tl.where(live, acc + widened, acc)

    row_sum = tl.sum(acc, 1)[:, None]

    # Arithmetic spelling of max(1, ks1): exactly one of the two comparisons
    # is true, selecting either 1 or ks1 as the output row stride.
    out_row_stride = (1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))
    out_offset = (row_idx % ks1) + (row_idx // ks1)*out_row_stride
    tl.store(out_ptr1 + out_offset, row_sum.to(tl.int32), row_ok)
SpecForge-ext/cache/compiled_kernels/2h/c2hvdjlmxyob2txn4nddktnqpzxakuy4vukk46jxvlks5plszr5r.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 4096},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    """Two-pass row normalisation by the reciprocal root of the mean square.

    Pass 1 accumulates ``sum(x * x)`` per row in float32, then computes
    ``scale = rsqrt(sum / ks0 + in_ptr1)`` and writes it to
    ``in_out_ptr0[row]``.  Pass 2 re-reads the row and stores
    ``in_ptr2[col] * (x * scale)`` to ``out_ptr0`` at the same offsets the
    input was read from.

    Args:
        in_out_ptr0: per-row output for the rsqrt factor (declared ``*fp32``).
        in_ptr0: input rows (declared ``*bf16``), row stride ``ks0``.
        in_ptr1: scalar (declared ``fp64``) added before the rsqrt;
            presumably the epsilon term -- confirm against the caller.
        in_ptr2: per-column multiplier (declared ``*bf16``), indexed by
            column only.
        out_ptr0: output rows (declared ``*bf16``), row stride ``ks0``.
        ks0: row stride, also used as the divisor of the squared sum.
            NOTE(review): the mean divides by ks0, not r0_numel -- this
            assumes the two are equal; confirm.
        xnumel: number of rows.
        r0_numel: number of columns visited per row.
        XBLOCK: rows handled per program (compile-time).
        R0_BLOCK: columns consumed per loop iteration (compile-time).
    """
    # Row coordinates owned by this program, shaped [XBLOCK, 1].
    row_idx = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = row_idx < xnumel
    # Column offsets inside one reduction tile, shaped [1, R0_BLOCK].
    col_base = tl.arange(0, R0_BLOCK)[None, :]

    # ---- pass 1: sum of squares per row, accumulated in float32 ----
    sq_acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for col_start in range(0, r0_numel, R0_BLOCK):
        col_idx = col_start + col_base
        live = (col_idx < r0_numel) & row_ok
        # 'evict_last': the same elements are read again in pass 2.
        x = tl.load(in_ptr0 + (col_idx + ks0*row_idx), live, eviction_policy='evict_last', other=0.0).to(tl.float32)
        sq = tl.broadcast_to(x * x, [XBLOCK, R0_BLOCK])
        sq_acc = tl.where(live, sq_acc + sq, sq_acc)
    sq_sum = tl.sum(sq_acc, 1)[:, None]

    mean_sq = (sq_sum / ks0.to(tl.float32))
    inv_root = libdevice.rsqrt(mean_sq + in_ptr1.to(tl.float32))

    # Barrier precedes the in-place store, exactly as emitted by the generator.
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (row_idx), inv_root, row_ok)

    # ---- pass 2: scale every element and apply the per-column multiplier ----
    for col_start in range(0, r0_numel, R0_BLOCK):
        col_idx = col_start + col_base
        col_ok = col_idx < r0_numel
        live = col_ok & row_ok
        # The multiplier does not depend on the row, so only the column mask applies.
        weight = tl.load(in_ptr2 + (col_idx), col_ok, eviction_policy='evict_last', other=0.0).to(tl.float32)
        x = tl.load(in_ptr0 + (col_idx + ks0*row_idx), live, eviction_policy='evict_first', other=0.0).to(tl.float32)
        normed = x * inv_root
        tl.store(out_ptr0 + (col_idx + ks0*row_idx), weight * normed, live)
SpecForge-ext/cache/compiled_kernels/2s/c2sasa5yimiwlxmywmcvgtuh2fvol2mvhppzairkbqvuwicnbd5y.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 32768},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    """Two-pass row softmax over rows of 32000 contiguous elements.

    Pass 1 streams each row once and folds every tile into a pair of running
    statistics through ``triton_helpers.online_softmax_combine``; the pair is
    collapsed per row with ``triton_helpers.online_softmax_reduce``.  Pass 2
    re-reads the row and stores ``exp(x - stat0) / stat1`` through
    ``out_ptr2``.
    NOTE(review): the helpers are not visible here; from their use below,
    stat0 acts as the row maximum and stat1 as the sum of shifted
    exponentials -- confirm against ``triton_helpers``.

    Args:
        in_ptr0: input rows (declared ``*bf16``), row stride 32000.
        out_ptr2: output rows (declared ``*fp32``), row stride 32000.
        xnumel: number of rows.
        r0_numel: ignored; overwritten with the constant 32000 below.
        XBLOCK: rows handled per program (compile-time).
        R0_BLOCK: columns consumed per loop iteration (compile-time).
    """
    # The reduction length is baked into this specialisation; the runtime
    # argument is deliberately shadowed by the constant.
    r0_numel = 32000

    # Row coordinates owned by this program, shaped [XBLOCK, 1].
    row_idx = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = row_idx < xnumel
    # Column offsets inside one reduction tile, shaped [1, R0_BLOCK].
    col_base = tl.arange(0, R0_BLOCK)[None, :]

    # ---- pass 1: running statistics per lane ----
    run_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
    run_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    for col_start in range(0, r0_numel, R0_BLOCK):
        col_idx = col_start + col_base
        live = (col_idx < r0_numel) & row_ok
        # 'evict_last': the same elements are read again in pass 2.
        x = tl.load(in_ptr0 + (col_idx + 32000*row_idx), live, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tile = tl.broadcast_to(x, [XBLOCK, R0_BLOCK])

        next_max, next_sum = triton_helpers.online_softmax_combine(
            run_max, run_sum, tile, False
        )

        # Only lanes that actually loaded data may advance their statistics.
        run_max = tl.where(live, next_max, run_max)
        run_sum = tl.where(live, next_sum, run_sum)

    row_stat0, row_stat1 = triton_helpers.online_softmax_reduce(
        run_max, run_sum, 1, False)
    row_stat0 = row_stat0[:, None]
    row_stat1 = row_stat1[:, None]

    # ---- pass 2: normalise and write out ----
    for col_start in range(0, r0_numel, R0_BLOCK):
        col_idx = col_start + col_base
        live = (col_idx < r0_numel) & row_ok
        x = tl.load(in_ptr0 + (col_idx + 32000*row_idx), live, eviction_policy='evict_first', other=0.0).to(tl.float32)
        shifted_exp = libdevice.exp(x - row_stat0)
        tl.store(out_ptr2 + (col_idx + 32000*row_idx), (shifted_exp / row_stat1), live)
SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 8
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially masked blocks (KV_IDX, IS_FULL_BLOCKS=False) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks (FULL_KV_IDX, IS_FULL_BLOCKS=True) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks5
245
+ stride_q_idx_h = ks6*ks7
246
+ stride_q_idx_n = ks6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially masked blocks (Q_IDX, IS_FULL_BLOCKS=False) ~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
341
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    K, V, # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulate the dQ tile `dq` over a sequence of KV tiles and return it.
    #
    # Each loop iteration hands the current transposed K/V tile pointers to
    # bwd_dq_block_mn, then advances those pointers (and offs_n2) by the
    # row offset computed by get_offset_for_next_block from `kv_indices`.
    #
    # The leading arg_* / in_ptr16 / out_ptr0 / ks* parameters are forwarded
    # unchanged to bwd_dq_block_mn; within this function only ks0 and ks1 are
    # read (as Q_LEN and KV_LEN).
    #
    # IS_FULL_BLOCKS is forwarded as-is; it selects whether the per-element
    # mask is evaluated inside bwd_dq_block_mn.

    # Compile-time configuration emitted by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Number of BLOCK_N2-sized tiles per sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    # 1 / ln(2); passed down so the block function can use exp2.
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Pointer tiles laid out as [head_dim, BLOCK_N2], i.e. transposed K and V.
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    # Iteration count: tiles implied by the sparse block count, capped by the
    # number of BLOCK_N2 tiles that fit in KV_LEN (at least 1).
    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        # Keep the logical KV row indices in step with the pointers.
        offs_n2 += offset

    return dq
420
+
421
+
422
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Process one (query tile, KV tile) pair and return `dq` with this tile's
    # contribution added.
    #
    # Steps, in order: load K^T, form scaled scores q @ K^T, mask out-of-range
    # KV columns (and, unless IS_FULL_BLOCKS, the generated element mask),
    # recompute probabilities p = exp2(scores * RCP_LN2 - lse), form
    # ds = p * (do @ V^T - Di), zero masked entries, and add ds @ K to dq.
    #
    # Of the leading generated arguments only in_ptr16 and ks8 are read here,
    # both inside the mask computation.

    # Compile-time configuration emitted by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: argument order is reversed because K is transposed
    # (rows index the head dimension, columns index KV positions).
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    # With IS_DIVISIBLE False the indices are wrapped modulo the sequence
    # length so the mask computation below never sees out-of-range positions.
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # Generated score modification: identity in this kernel.
    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        # KV columns past KV_LEN must not contribute to the softmax.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Generated element mask over (m, n). It is the OR of two terms:
        #   A = (m >= n) & (n < L) & (m < L)
        #   B = (n >= ks8) & ((n mod ks8) < L) & (((n - m) mod ks8) == 0)
        # where L = in_ptr16[off_z] and "mod" is floor-modulo (result takes
        # the sign of the divisor).
        # NOTE(review): L looks like a per-batch valid length and ks8 like a
        # period/segment size — confirm against the mask_mod that produced this.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = ks8
        tmp15 = tmp3 >= tmp14
        # tmp16..tmp24: n mod ks8, corrected from truncated to floor remainder.
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        # tmp28..tmp35: (n - m) mod ks8 with the same floor correction.
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        # apply mask for partially masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Convert to base-2 exponent units so exp2 can be used below.
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB: argument order is reversed because V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # Generated gradient modification: identity in this kernel.
    tmp39 = (ds)
    grad_scores = tmp39


    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # Computed but not consumed in this kernel (no extra buffer writes
        # were generated after it).
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
556
+
557
+
558
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulate the dK and dV tiles over a sequence of query tiles and
    # return them as (dk, dv).
    #
    # Each loop iteration passes the current Q^T and dO tile pointers to
    # bwd_dkdv_block_mn, then advances those pointers (and offs_m1) by the
    # row offset computed by get_offset_for_next_block from `q_indices`.
    #
    # The leading arg_* / in_ptr16 / out_ptr0 / ks* parameters are forwarded
    # unchanged; within this function only ks0 and ks1 are read (as Q_LEN and
    # KV_LEN).

    # Compile-time configuration emitted by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Number of BLOCK_M1-sized tiles per sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    # 1 / ln(2); passed down so the block function can use exp2.
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed ([head_dim, BLOCK_M1]); dO is addressed
    # row-major ([BLOCK_M1, head_dim]).
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        # Keep the logical query row indices in step with the pointers.
        offs_m1 += offset

    return dk, dv
631
+
632
+
633
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Process one (KV tile, query tile) pair and return (dk, dv) with this
    # tile's contribution added.
    #
    # Everything here is in transposed orientation: rows index KV positions
    # (n) and columns index query positions (m). Steps, in order: load Q^T and
    # the per-row LSE, form scaled scores k @ Q^T, mask out-of-range query
    # columns (and, unless IS_FULL_BLOCKS, the generated element mask),
    # recompute pT = exp2(scores * RCP_LN2 - lse), add pT @ dO to dv, form
    # dsT = pT * (v @ dO^T - Di), zero masked entries, and add dsT @ Q to dk.
    #
    # Of the leading generated arguments only in_ptr16 and ks8 are read here,
    # both inside the mask computation.

    # Compile-time configuration emitted by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf LSE entries by 0 so the subtraction below cannot produce
    # (-inf) - (-inf) = NaN.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    # Generated score modification: identity in this kernel.
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        # Query columns past Q_LEN must not contribute.
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Generated element mask over (m, n). It is the OR of two terms:
        #   A = (m >= n) & (n < L) & (m < L)
        #   B = (n >= ks8) & ((n mod ks8) < L) & (((n - m) mod ks8) == 0)
        # where L = in_ptr16[off_z] and "mod" is floor-modulo (result takes
        # the sign of the divisor).
        # NOTE(review): L looks like a per-batch valid length and ks8 like a
        # period/segment size — confirm against the mask_mod that produced this.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = ks8
        tmp55 = tmp43 >= tmp54
        # tmp56..tmp64: n mod ks8, corrected from truncated to floor remainder.
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        # tmp68..tmp75: (n - m) mod ks8 with the same floor correction.
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # apply mask for partially masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Convert to base-2 exponent units so exp2 can be used below.
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # Generated gradient modification: identity in this kernel.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        # Not taken in this kernel (WRITE_DQ is True above); the names below
        # are set up for generated buffer-gradient writes, none of which were
        # emitted here.
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
778
+
779
+ # Utility triton funcs
780
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Return how many rows the caller must advance its pointers by to reach
    # the tile for the next loop iteration.
    #
    # A sparse block of SPARSE_BLOCK rows is walked in SPARSE_BLOCK_MULTIPLE
    # tiles of BLOCK rows. Inside a sparse block the step is BLOCK; on its
    # last tile the step jumps to the start of the next sparse block listed
    # in `col_indices`.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    # Position of the current sparse block within col_indices, and its successor.
    sparse_slot = loop_iter // SPARSE_BLOCK_MULTIPLE
    following_slot = sparse_slot + 1
    # True when this iteration is the last tile of its sparse block.
    at_sparse_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    here = tl.load(col_indices + sparse_slot, eviction_policy="evict_last")
    # The successor entry is only read while it lies inside the valid list.
    there = tl.load(col_indices + following_slot, eviction_policy="evict_last", mask=following_slot < total_blocks)
    # Distance from the current (last) tile to the first tile of the next
    # sparse block: whole-block distance minus the tiles already stepped over.
    stride_to_next = (there - here) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return stride_to_next * at_sparse_block_end + (1 - at_sparse_block_end) * BLOCK
795
+
796
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wrap `indices` into [0, max_len) by taking them modulo max_len.
    # When no bound is supplied the indices are returned untouched.
    if max_len is None:
        return indices
    else:
        return indices % max_len
799
+
800
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load a tile through a block pointer, boundary-checking only the
    # dimensions that are not known to be safe. Checked dimensions are
    # zero-padded.
    #   IS_DIVISIBLE  False -> check dimension 0
    #   SAFE_HEAD_DIM False -> check dimension 1
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Load a 2-D tile, masking whichever dimensions are not flagged as
    # divisible; masked-out elements read as 0.0.
    #
    # If both strides are given, `ptr` is treated as a base pointer and the
    # tile addresses are built from offs_m / offs_n. If either stride is None,
    # `ptr` is used as-is and the offsets only feed the masks.
    if stride_m is not None and stride_n is not None:
        row_shift = offs_m[:, None] * stride_m
        col_shift = offs_n[None, :] * stride_n
        ptr = ptr + row_shift + col_shift

    # Pick the mask from the two compile-time divisibility flags.
    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Both divisible: no mask needed.
            return tl.load(ptr)
        else:
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
SpecForge-ext/cache/compiled_kernels/2x/3dd4effcc6c7612a42d28cac3a6342345062808f2904d114a779d751ce7956b2.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"}
SpecForge-ext/cache/compiled_kernels/2x/c2xgz3ru7j7sptpmoelww3e5lkmoeimpyawjjwmcpaujxtdorhwr.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/2x/c2xsu5ssb3jappbwwrbr53muiaoukfjzccks7reewucgvplouktq.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 32, 'r0_': 16},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x0 = xindex
27
+ _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
28
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
29
+ r0_index = r0_offset + r0_base
30
+ r0_mask = r0_index < r0_numel
31
+ roffset = r0_offset
32
+ rindex = r0_index
33
+ r0_1 = r0_index
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
35
+ tmp1 = tmp0.to(tl.int64)
36
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
37
+ tmp4 = _tmp3 + tmp2
38
+ _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
39
+ tmp3 = tl.sum(_tmp3, 1)[:, None]
40
+ x2 = (xindex % ks1)
41
+ x3 = xindex // ks1
42
+ tmp5 = tmp3.to(tl.int32)
43
+ tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask)
SpecForge-ext/cache/compiled_kernels/2x/c2xunts4zntd65pabgkkxg5ylyh7sahfyogzmgljfiljdui4o365.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 128, 'r0_': 32},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 128
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
29
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
30
+ r0_index = r0_offset + r0_base
31
+ r0_mask = r0_index < r0_numel
32
+ roffset = r0_offset
33
+ rindex = r0_index
34
+ r0_1 = r0_index
35
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
36
+ tmp1 = tmp0.to(tl.int64)
37
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
38
+ tmp4 = _tmp3 + tmp2
39
+ _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
40
+ tmp3 = tl.sum(_tmp3, 1)[:, None]
41
+ tmp5 = tmp3.to(tl.int32)
42
+ tl.store(out_ptr1 + (x0), tmp5, xmask)
43
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
44
+ r0_index = r0_offset + r0_base
45
+ r0_mask = r0_index < r0_numel
46
+ roffset = r0_offset
47
+ rindex = r0_index
48
+ r0_1 = r0_index
49
+ tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
50
+ tmp7 = tmp6.to(tl.int32)
51
+ tmp8 = r0_1
52
+ tmp9 = tmp8 < tmp5
53
+ tmp10 = ks0
54
+ tmp11 = tl.where(tmp9, tmp7, tmp10)
55
+ tmp12 = 1 + ks0
56
+ tmp13 = tmp11 + tmp12
57
+ tmp14 = tmp11 < 0
58
+ tmp15 = tl.where(tmp14, tmp13, tmp11)
59
+ tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))")
60
+ tmp17 = tl.full([1, 1], 1, tl.int32)
61
+ tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask)
62
+ tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask)
SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 16384, 'r0_': 262144},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 151936
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
29
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_1 = r0_index
36
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
39
+ _tmp2, _tmp2_index, tmp1, rindex
40
+ )
41
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
42
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
43
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
44
+ tmp2 = tmp2_idx[:, None]
45
+ tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
46
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
47
+ tmp4 = tmp2 + tmp3
48
+ tmp5 = tmp2 < 0
49
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
50
+ tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936")
51
+ tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1)
52
+ tmp9 = tmp8.to(tl.int32)
53
+ tmp10 = tmp9.to(tl.int64)
54
+ tmp12 = tmp10 * tmp11
55
+ tl.debug_barrier()
56
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
SpecForge-ext/cache/compiled_kernels/3k/c0fc7bc81a7e9d406f980957c0881903e8484dd7f57d970f2ddd21ca3ab2994d.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"}
SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/3m/c3mfnz3jqpdzlott45yvd2kki53nhik366siiuob2jitdkwx6tyg.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 1, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 1
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
28
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
29
+ r0_index = r0_offset + r0_base
30
+ r0_mask = r0_index < r0_numel
31
+ roffset = r0_offset
32
+ rindex = r0_index
33
+ r0_0 = r0_index
34
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
35
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
36
+ tmp3 = _tmp2 + tmp1
37
+ _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
38
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
39
+ tmp4 = tl.load(in_ptr1 + (0))
40
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1])
41
+ tmp6 = tmp5.to(tl.float32)
42
+ tmp7 = tmp2.to(tl.float32)
43
+ tmp8 = 1e-06
44
+ tmp9 = triton_helpers.maximum(tmp7, tmp8)
45
+ tmp10 = (tmp6 / tmp9)
46
+ tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None)
SpecForge-ext/cache/compiled_kernels/3p/c3pmafpvrty43do4nz3cf2mvhkihfulfxbiolmcu2votxja4s56e.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['14_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ob/cob65ptxwcswkyjowvaxmwnu4cpoiijoxwce6eyz2ndtpqxwqxm5.py
38
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
39
+ # Source node to ATen node mapping:
40
+ # argmax => argmax
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:4" = PlaceHolder[target=arg1_1]
43
+ # %argmax : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
44
+ # return %argmax
45
+ triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', '''
46
+ import triton
47
+ import triton.language as tl
48
+
49
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
50
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
51
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
52
+ triton_helpers.set_driver_to_gpu()
53
+
54
+ @triton_heuristics.reduction(
55
+ size_hints={'x': 4096, 'r0_': 32768},
56
+ reduction_hint=ReductionHint.INNER,
57
+ filename=__file__,
58
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
59
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
60
+ )
61
+ @triton.jit
62
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
63
+ r0_numel = 32000
64
+ rnumel = r0_numel
65
+ RBLOCK: tl.constexpr = R0_BLOCK
66
+ xoffset = tl.program_id(0) * XBLOCK
67
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
68
+ xmask = xindex < xnumel
69
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
70
+ rbase = r0_base
71
+ x0 = xindex
72
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
73
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
74
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
75
+ r0_index = r0_offset + r0_base
76
+ r0_mask = r0_index < r0_numel
77
+ roffset = r0_offset
78
+ rindex = r0_index
79
+ r0_1 = r0_index
80
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
81
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
82
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
83
+ _tmp2, _tmp2_index, tmp1, rindex
84
+ )
85
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
86
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
87
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
88
+ tmp2 = tmp2_idx[:, None]
89
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
90
+ ''', device_str='cuda')
91
+
92
+
93
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qj/cqj5277ktaoo5rg4kvnn7pm72cbfiwp7hxewmxzj4aevxoorlebn.py
94
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
95
+ # Source node to ATen node mapping:
96
+ # argmax_1 => argmax_1
97
+ # Graph fragment:
98
+ # %arg3_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:4" = PlaceHolder[target=arg3_1]
99
+ # %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {})
100
+ # return %argmax_1
101
+ triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', '''
102
+ import triton
103
+ import triton.language as tl
104
+
105
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
106
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
107
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
108
+ triton_helpers.set_driver_to_gpu()
109
+
110
+ @triton_heuristics.reduction(
111
+ size_hints={'x': 4096, 'r0_': 32768},
112
+ reduction_hint=ReductionHint.INNER,
113
+ filename=__file__,
114
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
115
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
116
+ )
117
+ @triton.jit
118
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
119
+ r0_numel = 32000
120
+ rnumel = r0_numel
121
+ RBLOCK: tl.constexpr = R0_BLOCK
122
+ xoffset = tl.program_id(0) * XBLOCK
123
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
124
+ xmask = xindex < xnumel
125
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
126
+ rbase = r0_base
127
+ x0 = (xindex % ks0)
128
+ x1 = xindex // ks0
129
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
130
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
131
+ x3 = xindex
132
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
133
+ r0_index = r0_offset + r0_base
134
+ r0_mask = r0_index < r0_numel
135
+ roffset = r0_offset
136
+ rindex = r0_index
137
+ r0_2 = r0_index
138
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
139
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
140
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
141
+ _tmp2, _tmp2_index, tmp1, rindex
142
+ )
143
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
144
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
145
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
146
+ tmp2 = tmp2_idx[:, None]
147
+ tl.store(out_ptr0 + (x3), tmp2, xmask)
148
+ ''', device_str='cuda')
149
+
150
+
151
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sx/csxbkyhsnglm4cv6i6ibhzv34wlbjpntyk3gj27zbyc4s4efmtxh.py
152
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
153
+ # Source node to ATen node mapping:
154
+ # eq => eq_2
155
+ # mul => mul_3
156
+ # squeeze => squeeze
157
+ # sum_1 => sum_1
158
+ # Graph fragment:
159
+ # %argmax : Tensor "i64[2, s3][s3, 1]cuda:4" = PlaceHolder[target=argmax]
160
+ # %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:4" = PlaceHolder[target=argmax_1]
161
+ # %arg4_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:4" = PlaceHolder[target=arg4_1]
162
+ # %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
163
+ # %squeeze : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {})
164
+ # %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {})
165
+ # %sum_1 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {})
166
+ # return %sum_1
167
+ triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', '''
168
+ import triton
169
+ import triton.language as tl
170
+
171
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
172
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
173
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
174
+ triton_helpers.set_driver_to_gpu()
175
+
176
+ @triton_heuristics.reduction(
177
+ size_hints={'x': 1, 'r0_': 4096},
178
+ reduction_hint=ReductionHint.INNER,
179
+ filename=__file__,
180
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
181
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
182
+ )
183
+ @triton.jit
184
+ def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
185
+ xnumel = 1
186
+ rnumel = r0_numel
187
+ RBLOCK: tl.constexpr = R0_BLOCK
188
+ xoffset = tl.program_id(0) * XBLOCK
189
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
190
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
191
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
192
+ rbase = r0_base
193
+ _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
194
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
195
+ r0_index = r0_offset + r0_base
196
+ r0_mask = r0_index < r0_numel
197
+ roffset = r0_offset
198
+ rindex = r0_index
199
+ r0_0 = r0_index
200
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
201
+ tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
202
+ tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
203
+ tmp2 = tmp0 == tmp1
204
+ tmp3 = tmp2.to(tl.int64)
205
+ tmp5 = tmp3 * tmp4
206
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
207
+ tmp8 = _tmp7 + tmp6
208
+ _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
209
+ tmp7 = tl.sum(_tmp7, 1)[:, None]
210
+ tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None)
211
+ ''', device_str='cuda')
212
+
213
+
214
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xe/cxeou6auzbu4dnrn2twxe573bmqovq7xnk4b6hydfbw53px4etc7.py
215
+ # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div]
216
+ # Source node to ATen node mapping:
217
+ # clamp_min => clamp_min
218
+ # sum_2 => sum_2
219
+ # truediv => div
220
+ # Graph fragment:
221
+ # %arg6_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:4" = PlaceHolder[target=arg6_1]
222
+ # %sum_1 : Tensor "i64[][]cuda:4" = PlaceHolder[target=sum_1]
223
+ # %sum_2 : Tensor "i64[][]cuda:4" = PlaceHolder[target=sum_2]
224
+ # %sum_2 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {})
225
+ # %clamp_min : Tensor "f32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {})
226
+ # %div : Tensor "f32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {})
227
+ # return %sum_2,%div
228
+ triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', '''
229
+ import triton
230
+ import triton.language as tl
231
+
232
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
233
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
234
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
235
+ triton_helpers.set_driver_to_gpu()
236
+
237
+ @triton_heuristics.reduction(
238
+ size_hints={'x': 1, 'r0_': 4096},
239
+ reduction_hint=ReductionHint.INNER,
240
+ filename=__file__,
241
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
242
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
243
+ )
244
+ @triton.jit
245
+ def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
246
+ xnumel = 1
247
+ rnumel = r0_numel
248
+ RBLOCK: tl.constexpr = R0_BLOCK
249
+ xoffset = tl.program_id(0) * XBLOCK
250
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
251
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
252
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
253
+ rbase = r0_base
254
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
255
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
256
+ r0_index = r0_offset + r0_base
257
+ r0_mask = r0_index < r0_numel
258
+ roffset = r0_offset
259
+ rindex = r0_index
260
+ r0_0 = r0_index
261
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
262
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
263
+ tmp3 = _tmp2 + tmp1
264
+ _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
265
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
266
+ tmp4 = tl.load(in_ptr1 + (0))
267
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1])
268
+ tmp6 = tmp5.to(tl.float32)
269
+ tmp7 = tmp2.to(tl.float32)
270
+ tmp8 = 1e-06
271
+ tmp9 = triton_helpers.maximum(tmp7, tmp8)
272
+ tmp10 = (tmp6 / tmp9)
273
+ tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None)
274
+ ''', device_str='cuda')
275
+
276
+
277
+ async_compile.wait(globals())
278
+ del async_compile
279
+
280
+ class Runner:
281
+ def __init__(self, partitions):
282
+ self.partitions = partitions
283
+
284
+ def recursively_apply_fns(self, fns):
285
+ new_callables = []
286
+ for fn, c in zip(fns, self.partitions):
287
+ new_callables.append(fn(c))
288
+ self.partitions = new_callables
289
+
290
+ def call(self, args):
291
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args
292
+ args.clear()
293
+ s3 = arg0_1
294
+ s71 = arg2_1
295
+ s14 = arg5_1
296
+ assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1))
297
+ assert_size_stride(arg3_1, (2, s3, 32000), (s71, 32000, 1))
298
+ assert_size_stride(arg4_1, (2, s3, 1), (s3, 1, 1))
299
+ assert_size_stride(arg6_1, (2, s14, 1), (s14, 1, 1))
300
+ with torch.cuda._DeviceGuard(4):
301
+ torch.cuda.set_device(4)
302
+ buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64)
303
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
304
+ triton_red_fused_argmax_0_xnumel = 2*s3
305
+ stream4 = get_raw_stream(4)
306
+ triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream4)
307
+ del arg1_1
308
+ buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64)
309
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
310
+ triton_red_fused_argmax_1_xnumel = 2*s3
311
+ stream4 = get_raw_stream(4)
312
+ triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream4)
313
+ del arg3_1
314
+ buf2 = empty_strided_cuda((), (), torch.int64)
315
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
316
+ triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3
317
+ stream4 = get_raw_stream(4)
318
+ triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream4)
319
+ del arg4_1
320
+ del buf0
321
+ del buf1
322
+ buf4 = empty_strided_cuda((), (), torch.float32)
323
+ # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div]
324
+ triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14
325
+ stream4 = get_raw_stream(4)
326
+ triton_red_fused_clamp_min_div_sum_3.run(arg6_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream4)
327
+ del arg6_1
328
+ del buf2
329
+ return (buf4, )
330
+
331
+ runner = Runner(partitions=[])
332
+ call = runner.call
333
+ recursively_apply_fns = runner.recursively_apply_fns
334
+
335
+
336
+ def benchmark_compiled_module(times=10, repeat=10):
337
+ from torch._dynamo.testing import rand_strided
338
+ from torch._inductor.utils import print_performance
339
+ arg0_1 = 1543
340
+ arg1_1 = rand_strided((2, 1543, 32000), (49376000, 32000, 1), device='cuda:4', dtype=torch.bfloat16)
341
+ arg2_1 = 49600000
342
+ arg3_1 = rand_strided((2, 1543, 32000), (49600000, 32000, 1), device='cuda:4', dtype=torch.float32)
343
+ arg4_1 = rand_strided((2, 1543, 1), (1543, 1, 1), device='cuda:4', dtype=torch.int64)
344
+ arg5_1 = 1543
345
+ arg6_1 = rand_strided((2, 1543, 1), (1543, 1, 1), device='cuda:4', dtype=torch.int64)
346
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1])
347
+ return print_performance(fn, times=times, repeat=repeat)
348
+
349
+
350
+ if __name__ == "__main__":
351
+ from torch._inductor.wrapper_benchmark import compiled_module_main
352
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/3u/c3ukv75kqyf3oeeogojmsgmsebbc2fg3rqs4dsmnshhsgj4hjkzx.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['10_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/js/cjse6ak6jsp3o35wdszmvjyn4cqeqewbex3a5ks2m6fqecygrmmg.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul_2
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:0" = PlaceHolder[target=arg1_1]
47
+ # %argmax : Tensor "i64[2, s14][s14, 1]cuda:0" = PlaceHolder[target=argmax]
48
+ # %arg2_1 : Tensor "b8[151936][1]cuda:0" = PlaceHolder[target=arg2_1]
49
+ # %arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg3_1]
50
+ # %argmax : Tensor "i64[2, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[2, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {})
55
+ # return %argmax,%mul_2
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 4096, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ r0_numel = 151936
75
+ rnumel = r0_numel
76
+ RBLOCK: tl.constexpr = R0_BLOCK
77
+ xoffset = tl.program_id(0) * XBLOCK
78
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
79
+ xmask = xindex < xnumel
80
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
81
+ rbase = r0_base
82
+ x0 = xindex
83
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
84
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
85
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
86
+ r0_index = r0_offset + r0_base
87
+ r0_mask = r0_index < r0_numel
88
+ roffset = r0_offset
89
+ rindex = r0_index
90
+ r0_1 = r0_index
91
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
92
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
93
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
94
+ _tmp2, _tmp2_index, tmp1, rindex
95
+ )
96
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
97
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
98
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
99
+ tmp2 = tmp2_idx[:, None]
100
+ tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
101
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
102
+ tmp4 = tmp2 + tmp3
103
+ tmp5 = tmp2 < 0
104
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
105
+ tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936")
106
+ tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1)
107
+ tmp9 = tmp8.to(tl.int32)
108
+ tmp10 = tmp9.to(tl.int64)
109
+ tmp12 = tmp10 * tmp11
110
+ tl.debug_barrier()
111
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
112
+ ''', device_str='cuda')
113
+
114
+
115
+ async_compile.wait(globals())
116
+ del async_compile
117
+
118
+ class Runner:
119
+ def __init__(self, partitions):
120
+ self.partitions = partitions
121
+
122
+ def recursively_apply_fns(self, fns):
123
+ new_callables = []
124
+ for fn, c in zip(fns, self.partitions):
125
+ new_callables.append(fn(c))
126
+ self.partitions = new_callables
127
+
128
+ def call(self, args):
129
+ arg0_1, arg1_1, arg2_1, arg3_1 = args
130
+ args.clear()
131
+ s24 = arg0_1
132
+ arg1_1_size = arg1_1.size()
133
+ s14 = arg1_1_size[1]
134
+ assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1))
135
+ assert_size_stride(arg2_1, (151936, ), (1, ))
136
+ assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1))
137
+ with torch.cuda._DeviceGuard(0):
138
+ torch.cuda.set_device(0)
139
+ buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64)
140
+ buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse
141
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
142
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14
143
+ stream0 = get_raw_stream(0)
144
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream0)
145
+ del arg1_1
146
+ del arg2_1
147
+ del arg3_1
148
+ return (buf1, )
149
+
150
+ runner = Runner(partitions=[])
151
+ call = runner.call
152
+ recursively_apply_fns = runner.recursively_apply_fns
153
+
154
+
155
+ def benchmark_compiled_module(times=10, repeat=10):
156
+ from torch._dynamo.testing import rand_strided
157
+ from torch._inductor.utils import print_performance
158
+ arg0_1 = 1130
159
+ arg1_1 = rand_strided((2, 1130, 151936), (171687680, 151936, 1), device='cuda:0', dtype=torch.bfloat16)
160
+ arg2_1 = rand_strided((151936, ), (1, ), device='cuda:0', dtype=torch.bool)
161
+ arg3_1 = rand_strided((2, 1130, 1), (1130, 1, 1), device='cuda:0', dtype=torch.int64)
162
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
163
+ return print_performance(fn, times=times, repeat=repeat)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ from torch._inductor.wrapper_benchmark import compiled_module_main
168
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/3x/88057732cb1d7a775c254455fe42105016cd2d1ced3af1bd1fb079691b5972a1.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"}
SpecForge-ext/cache/compiled_kernels/3x/c3xxifdzdkxpgs4yujq2dd2lwhcylvzgstk2hao7g2yjirkxlafr.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 128, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 128
20
+ r0_numel = 16
21
+ R0_BLOCK: tl.constexpr = 16
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = xindex < xnumel
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_1 = r0_index
33
+ x0 = xindex
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
35
+ tmp1 = tl.full([1, 1], 0, tl.int64)
36
+ tmp2 = tmp0 > tmp1
37
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
38
+ tmp4 = tmp0 < tmp3
39
+ tmp5 = tmp2 & tmp4
40
+ tmp6 = tmp5.to(tl.int8)
41
+ tmp7 = tmp6.to(tl.int32)
42
+ tmp8 = r0_1
43
+ tmp9 = tmp8.to(tl.int16)
44
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
45
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
46
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
47
+ tmp14 = tmp0 == tmp3
48
+ tmp15 = tmp14.to(tl.int8)
49
+ tmp16 = tmp15.to(tl.int32)
50
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
51
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
52
+ tmp20 = tmp7.to(tl.int64)
53
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
54
+ tmp23 = tl.where(xmask, tmp21, 0)
55
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
56
+ tmp25 = tmp16.to(tl.int64)
57
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
58
+ tmp28 = tl.where(xmask, tmp26, 0)
59
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
60
+ tmp30 = tmp24.to(tl.int32)
61
+ tmp31 = tmp29.to(tl.int32)
62
+ tmp32 = tmp13.to(tl.int64)
63
+ tmp33 = tmp32.to(tl.int32)
64
+ tmp34 = tmp8 < tmp30
65
+ tmp35 = tl.full([1, 1], 16, tl.int32)
66
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
67
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
68
+ tmp38 = tmp36 + tmp37
69
+ tmp39 = tmp36 < 0
70
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
71
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
72
+ tmp42 = tl.full([1, 1], 1, tl.int32)
73
+ tmp43 = tmp19.to(tl.int64)
74
+ tmp44 = tmp43.to(tl.int32)
75
+ tmp45 = tmp8 < tmp31
76
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
77
+ tmp47 = tmp46 + tmp37
78
+ tmp48 = tmp46 < 0
79
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
80
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
81
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
82
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
83
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
84
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
85
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
86
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 32000
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
29
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_1 = r0_index
36
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
39
+ _tmp2, _tmp2_index, tmp1, rindex
40
+ )
41
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
42
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
43
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
44
+ tmp2 = tmp2_idx[:, None]
45
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
SpecForge-ext/cache/compiled_kernels/42/c42olsblh7ymaib2tr5gwhfzuighing5bkpmabq5hx7nxumtbsig.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
    arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
    arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
    in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Driver loop for the dQ accumulation: steps through the KV tiles reachable
    # from `kv_indices`, hands each tile to bwd_dq_block_mn, and returns the
    # accumulated `dq`.
    #
    # The leading arg_* / in_ptr16 / out_ptr0 / ks* parameters are forwarded
    # untouched to bwd_dq_block_mn; only ks0 is read here (as KV_LEN).

    # --- kernel options baked in by the code generator ---
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    # --- head dimensions ---
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    # --- tile sizes ---
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # Number of BLOCK_N2 tiles that make up one sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = ks0

    head_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    head_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # K and V are addressed transposed: head dim along rows, KV position along columns.
    k_t_ptrs = K + offs_n2[None, :] * stride_kn + head_k[:, None] * stride_kd
    v_t_ptrs = V + offs_n2[None, :] * stride_vn + head_v[:, None] * stride_vd
    # The tiling below only works when BLOCK_M2 is a multiple of BLOCK_N2.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    # Trip count: tiles requested by the sparse block count, capped at the
    # number of BLOCK_N2 tiles in KV_LEN (never less than one).
    tiles_in_kv = tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)
    num_tiles = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tiles_in_kv)

    for tile in range(0, num_tiles):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
            arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
            arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
            in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dq, q, k_t_ptrs, v_t_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, head_k, head_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in KV positions) from this tile to the next one to visit.
        step = get_offset_for_next_block(
            tile, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        # Advance both pointer tiles and the matching position vector together.
        k_t_ptrs = k_t_ptrs + step * stride_kn
        v_t_ptrs = v_t_ptrs + step * stride_vn
        offs_n2 = offs_n2 + step

    return dq
420
+
421
+
422
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
    arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
    arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
    in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # One (M2 x N2) tile of the dQ computation: recompute the scores from q and
    # the K tile, rebuild p from the stored lse, form dS from dO and the V tile,
    # and add dS @ K into `dq`, which is returned.
    #
    # When IS_FULL_BLOCKS is false an element-wise mask is evaluated and applied
    # to both the scores and dS.

    # --- kernel options baked in by the code generator ---
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    # --- head dimensions ---
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    # --- tile sizes ---
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # K is addressed transposed, so the (offset, flag, length) arguments are
    # passed head-dim first, KV position second.
    k_t = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    scores = tl.dot(q, k_t, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        scores *= SM_SCALE

    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = scores
    # Positions are wrapped into range when the lengths are not tile-divisible,
    # since tiles straddling the Q_LEN / KV_LEN boundary index past the end.
    kv_pos = get_bounded_indices(offs_n2[None, :], None if IS_DIVISIBLE else KV_LEN)
    q_pos = get_bounded_indices(offs_m2[:, None], None if IS_DIVISIBLE else Q_LEN)

    # No score modification is applied in this kernel.
    post_mod_scores = scores

    if not IS_DIVISIBLE:
        # Columns beyond KV_LEN contribute nothing.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Element-wise mask. An element (m, n) is kept when either
        #   (a) m >= n and both m and n are below `bound`, or
        #   (b) n >= 2048, (n mod 2048) is below `bound`, and (n - m) mod 2048 == 0.
        # `bound` is the scalar at in_ptr16[off_z]
        # (presumably a per-batch valid length -- TODO confirm against the mask_mod).
        always_false = tl.full([1], False, tl.int1)
        bound = tl.load(in_ptr16 + off_z)
        period = tl.full([1], 2048, tl.int32)
        zero = tl.full([1], 0, tl.int32)
        period_negative = (libdevice.signbit(period) != 0) if (period).dtype is tl.float32 else period < 0

        # (a)
        q_not_before_kv = q_pos >= kv_pos
        kv_below_bound = kv_pos.to(tl.int64) < bound
        q_below_bound = q_pos.to(tl.int64) < bound
        both_below_bound = kv_below_bound & q_below_bound
        keep_a = always_false | (q_not_before_kv & both_below_bound)

        # (b) -- the remainder is adjusted so that it takes the sign of the divisor.
        kv_past_period = kv_pos >= period
        kv_rem = (kv_pos % period)
        kv_rem_negative = (libdevice.signbit(kv_rem) != 0) if (kv_rem).dtype is tl.float32 else kv_rem < 0
        kv_fix = (kv_rem != zero) & (kv_rem_negative != period_negative)
        kv_mod = tl.where(kv_fix, kv_rem + period, kv_rem)
        kv_mod_below_bound = kv_mod.to(tl.int64) < bound
        kv_term = kv_past_period & kv_mod_below_bound

        gap = kv_pos - q_pos
        gap_rem = (gap % period)
        gap_rem_negative = (libdevice.signbit(gap_rem) != 0) if (gap_rem).dtype is tl.float32 else gap_rem < 0
        gap_fix = (gap_rem != zero) & (gap_rem_negative != period_negative)
        gap_mod = tl.where(gap_fix, gap_rem + period, gap_rem)
        gap_aligned = gap_mod == zero
        keep_b = kv_term & gap_aligned

        mask_mod_output = keep_a | keep_b

        # Masked-out elements get -inf so that they vanish after exp2.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Change of base so exp2 can be used below.
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)

    # Compute dP and dS. V is addressed transposed, like K above.
    v_t = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
    dp = tl.dot(do, v_t, input_precision=FLOAT32_PRECISION)
    d_scores = p * (dp - Di[:, None])

    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # No joint modification is applied in this kernel.
    grad_scores = d_scores

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    d_scores = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) zero the gradient wherever the mask dropped the score.
        d_scores = tl.where(mask_mod_output, d_scores, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    d_scores = d_scores.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(d_scores, tl.trans(k_t), input_precision=FLOAT32_PRECISION)

    return dq
556
+
557
+
558
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
    arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
    arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
    in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Driver loop for the dK/dV accumulation: steps through the Q tiles
    # reachable from `q_indices`, hands each tile to bwd_dkdv_block_mn, and
    # returns the accumulated (dk, dv) pair.
    #
    # The leading arg_* / in_ptr16 / out_ptr0 / ks* parameters are forwarded
    # untouched to bwd_dkdv_block_mn; only ks0 is read here (as KV_LEN).

    # --- kernel options baked in by the code generator ---
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    # --- head dimensions ---
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    # --- tile sizes ---
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # Number of BLOCK_M1 tiles that make up one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = ks0

    head_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    head_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head dim along rows); dO is addressed row-major.
    q_t_ptrs = Q + offs_m1[None, :] * stride_qm + head_k[:, None] * stride_qd
    d_out_ptrs = DO + offs_m1[:, None] * stride_dom + head_v[None, :] * stride_dod
    # The tiling below only works when BLOCK_N1 is a multiple of BLOCK_M1.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The cap handles runs with a very large SPARSE_BLOCK_SIZE (i.e. no
    # block-mask), where the sparse count alone would overshoot Q_LEN.
    tiles_in_q = tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)
    num_tiles = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tiles_in_q)

    for tile in range(0, num_tiles):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
            arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
            arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
            in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dk, dv, q_t_ptrs, k, v, d_out_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, head_k, head_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in Q positions) from this tile to the next one to visit.
        step = get_offset_for_next_block(
            tile, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # Advance both pointer tiles and the matching position vector together.
        q_t_ptrs = q_t_ptrs + step * stride_qm
        d_out_ptrs = d_out_ptrs + step * stride_dom
        offs_m1 = offs_m1 + step

    return dk, dv
631
+
632
+
633
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
    arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
    arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
    in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # One (N1 x M1) tile of the dK/dV computation, carried out on transposed
    # operands: recompute k @ q^T, rebuild p^T from the stored LSE, add
    # p^T @ dO into `dv` and dS^T @ q into `dk`. Returns (dk, dv).
    #
    # When IS_FULL_BLOCKS is false an element-wise mask is evaluated and applied
    # to both the scores and dS^T.

    # --- kernel options baked in by the code generator ---
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    # --- head dimensions ---
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    # --- tile sizes ---
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # Q is addressed transposed, so the (offset, flag, length) arguments are
    # passed head-dim first, Q position second.
    q_t = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # LSE is fetched ahead of the matmul to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Rows whose LSE is -inf are treated as LSE == 0.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    scores_t = tl.dot(k, q_t, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        scores_t *= SM_SCALE

    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # Positions are wrapped into range when the lengths are not tile-divisible,
    # since tiles straddling the Q_LEN / KV_LEN boundary index past the end.
    q_pos = get_bounded_indices(offs_m1[None, :], None if IS_DIVISIBLE else Q_LEN)
    kv_pos = get_bounded_indices(offs_n1[:, None], None if IS_DIVISIBLE else KV_LEN)

    pre_mod_scores = scores_t
    # No score modification is applied in this kernel.
    post_mod_scores = scores_t

    if not IS_DIVISIBLE:
        # Columns (query positions) beyond Q_LEN contribute nothing.
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Element-wise mask. An element (m, n) is kept when either
        #   (a) m >= n and both m and n are below `bound`, or
        #   (b) n >= 2048, (n mod 2048) is below `bound`, and (n - m) mod 2048 == 0.
        # `bound` is the scalar at in_ptr16[off_z]
        # (presumably a per-batch valid length -- TODO confirm against the mask_mod).
        always_false = tl.full([1], False, tl.int1)
        bound = tl.load(in_ptr16 + off_z)
        period = tl.full([1], 2048, tl.int32)
        zero = tl.full([1], 0, tl.int32)
        period_negative = (libdevice.signbit(period) != 0) if (period).dtype is tl.float32 else period < 0

        # (a)
        q_not_before_kv = q_pos >= kv_pos
        kv_below_bound = kv_pos.to(tl.int64) < bound
        q_below_bound = q_pos.to(tl.int64) < bound
        both_below_bound = kv_below_bound & q_below_bound
        keep_a = always_false | (q_not_before_kv & both_below_bound)

        # (b) -- the remainder is adjusted so that it takes the sign of the divisor.
        kv_past_period = kv_pos >= period
        kv_rem = (kv_pos % period)
        kv_rem_negative = (libdevice.signbit(kv_rem) != 0) if (kv_rem).dtype is tl.float32 else kv_rem < 0
        kv_fix = (kv_rem != zero) & (kv_rem_negative != period_negative)
        kv_mod = tl.where(kv_fix, kv_rem + period, kv_rem)
        kv_mod_below_bound = kv_mod.to(tl.int64) < bound
        kv_term = kv_past_period & kv_mod_below_bound

        gap = kv_pos - q_pos
        gap_rem = (gap % period)
        gap_rem_negative = (libdevice.signbit(gap_rem) != 0) if (gap_rem).dtype is tl.float32 else gap_rem < 0
        gap_fix = (gap_rem != zero) & (gap_rem_negative != period_negative)
        gap_mod = tl.where(gap_fix, gap_rem + period, gap_rem)
        gap_aligned = gap_mod == zero
        keep_b = kv_term & gap_aligned

        mask_mod_output = keep_a | keep_b

        # (grads) masked-out elements get -inf so that they vanish after exp2.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Change of base so exp2 can be used below.
        post_mod_scores *= RCP_LN2
    p_t = tl.math.exp2(post_mod_scores - lse[None, :])
    d_out = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

    # Compute dV.
    pp_t = p_t
    dv += tl.dot(pp_t.to(MATMUL_PRECISION), d_out, input_precision=FLOAT32_PRECISION)

    if IS_DIVISIBLE:
        delta = tl.load(DELTA + offs_m1)
    else:
        delta = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)

    # Compute dP and dS (both transposed).
    dp_t = tl.dot(v, tl.trans(d_out), input_precision=FLOAT32_PRECISION)
    ds_t = p_t * (dp_t - delta[None, :])

    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # No joint modification is applied in this kernel.
    grad_scores = ds_t

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = q_pos
        idx_n = kv_pos
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds_t = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) zero the gradient wherever the mask dropped the score.
        ds_t = tl.where(mask_mod_output, ds_t, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute dK.
    dk += tl.dot(ds_t.to(MATMUL_PRECISION), tl.trans(q_t), input_precision=FLOAT32_PRECISION)

    return dk, dv
778
+
779
+ # Utility triton funcs
780
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far (in positions) to advance after iteration `loop_iter`.
    #
    # With contiguous blocks the step is always BLOCK. Otherwise the step is
    # BLOCK while still inside the current sparse block, and a jump to the start
    # of the next listed sparse block once its last tile has been processed.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK

    # Entry of `col_indices` that the current iteration belongs to.
    sparse_slot = loop_iter // SPARSE_BLOCK_MULTIPLE
    this_block = tl.load(col_indices + sparse_slot, eviction_policy="evict_last")
    # The following entry is only read while it lies within `total_blocks`.
    following_block = tl.load(
        col_indices + sparse_slot + 1,
        eviction_policy="evict_last",
        mask=sparse_slot + 1 < total_blocks,
    )

    # True on the last tile of a sparse block.
    at_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance from the last tile of this sparse block to the first tile of the next.
    jump = (following_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Branch-free select between `jump` and the regular BLOCK stride.
    step = jump * at_block_end + (1 - at_block_end) * BLOCK
    return step
795
+
796
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wraps `indices` modulo `max_len`; with no `max_len` they pass through unchanged.
    if max_len is not None:
        return indices % max_len
    else:
        return indices
799
+
800
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Loads through `block_ptr`, boundary-checking (with zero padding) only the
    # dimensions that are not known to be safe: dimension 0 when IS_DIVISIBLE is
    # false, dimension 1 when SAFE_HEAD_DIM is false.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Loads a 2D tile, masking (with 0.0 fill) only along the dimensions that
    # are not flagged as divisible.
    #
    # When both strides are given, `ptr` is a base pointer and the tile
    # addresses are built here; when they are None, `ptr` is used as-is.
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # The divisibility flags are compile-time, so exactly one load is emitted.
    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # No bounds to check along either dimension.
            return tl.load(ptr)
        else:
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
SpecForge-ext/cache/compiled_kernels/44/0525feab4902a63d7e5c68635e4f503cce07497c146c803e2f852ae21bd67e9c.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"}
SpecForge-ext/cache/compiled_kernels/44/c444sn4254wny52itl5mlassuxuptb3kc3p6r3r3stsx4lyt6t3r.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['8_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
# Module-level short names: torch op namespaces (aten, inductor, _quantized),
# private guard / allocation helpers taken from torch._C, and the AsyncCompile
# instance through which the Triton kernel sources below are registered.
+ aten = torch.ops.aten
21
+ inductor_ops = torch.ops.inductor
22
+ _quantized = torch.ops._quantized
23
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
24
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
25
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
26
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
27
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
28
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
29
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
30
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
31
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
32
+ async_compile = AsyncCompile()
33
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py
38
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
39
+ # Source node to ATen node mapping:
40
+ # dense_mask_2 => full_default_1
41
+ # Graph fragment:
42
+ # %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
43
+ # return %index_put
44
# Registers the Triton source of a pointwise fill kernel with async_compile.
# Each program covers XBLOCK consecutive indices and stores an int32 zero to
# out_ptr0 for every index below xnumel.
# The kernel text is a runtime string literal, so it is left untouched.
+ triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', '''
45
+ import triton
46
+ import triton.language as tl
47
+
48
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
49
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
50
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
51
+ triton_helpers.set_driver_to_gpu()
52
+
53
+ @triton_heuristics.pointwise(
54
+ size_hints={'x': 8192},
55
+ filename=__file__,
56
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
57
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
58
+ min_elem_per_thread=0
59
+ )
60
+ @triton.jit
61
+ def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
62
+ xoffset = tl.program_id(0) * XBLOCK
63
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
64
+ xmask = xindex < xnumel
65
+ x0 = xindex
66
+ tmp0 = tl.full([1], 0, tl.int32)
67
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
68
+ ''', device_str='cuda')
69
+
70
+
71
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3g/c3gmll5f74ypurxotx73fmzfaldqb5oaua4nvjazcxzzjafvjoo2.py
72
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy]
73
+ # Source node to ATen node mapping:
74
+ # and_2 => bitwise_and_1
75
+ # and_3 => bitwise_and_2
76
+ # and_4 => bitwise_and_3, view_8
77
+ # b => iota
78
+ # batched_outputs_2 => view_9
79
+ # causal_mask => ge_1, view
80
+ # dense_mask => convert_element_type_2
81
+ # dense_mask_1 => convert_element_type_5
82
+ # diagnol_mask => eq_12
83
+ # full_blocks => eq_24
84
+ # full_blocks_1 => convert_element_type_1
85
+ # gt => gt
86
+ # index => index
87
+ # index_1 => index_1
88
+ # index_2 => index_2
89
+ # lt => lt, view_1
90
+ # lt_1 => lt_1, view_2
91
+ # lt_3 => lt_3
92
+ # m => iota_2
93
+ # mask_1 => constant_pad_nd
94
+ # mask_2 => view_10
95
+ # mask_3 => permute
96
+ # mask_block_sum => sum_1
97
+ # n => iota_3
98
+ # padding_mask => bitwise_and, view_3, view_4
99
+ # padding_mask_1 => lt_2, view_6
100
+ # partial_blocks => bitwise_and_4
101
+ # partial_blocks_1 => convert_element_type
102
+ # remainder => remainder
103
+ # remainder_1 => remainder_1
104
+ # result_1 => bitwise_or, full_default
105
+ # result_2 => bitwise_or_1
106
+ # sub => sub_12, view_7
107
+ # suffix_mask => ge_2
108
+ # Graph fragment:
109
+ # %arg1_1 : Tensor "i64[8][1]cuda:1" = PlaceHolder[target=arg1_1]
110
+ # %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=sum_1]
111
+ # %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:1, pin_memory: False})
112
+ # %iota_2 : Tensor "i64[2048][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
113
+ # %view : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
114
+ # %iota_3 : Tensor "i64[s37][1]cuda:1"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
115
+ # %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {})
116
+ # %iota : Tensor "i64[8][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
117
+ # %index : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {})
118
+ # %view_1 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {})
119
+ # %lt : Tensor "b8[8, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {})
120
+ # %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg0_1]), kwargs = {})
121
+ # %index_1 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {})
122
+ # %view_2 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {})
123
+ # %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {})
124
+ # %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {})
125
+ # %bitwise_and : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {})
126
+ # %bitwise_and_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {})
127
+ # %bitwise_or : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {})
128
+ # %ge_2 : Tensor "b8[s37][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {})
129
+ # %remainder : Tensor "i64[s37][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {})
130
+ # %index_2 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {})
131
+ # %view_6 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {})
132
+ # %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {})
133
+ # %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {})
134
+ # %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg0_1]), kwargs = {})
135
+ # %view_7 : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
136
+ # %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {})
137
+ # %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {})
138
+ # %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {})
139
+ # %bitwise_and_3 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {})
140
+ # %bitwise_or_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {})
141
+ # %view_9 : Tensor "b8[8, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, %arg0_1]), kwargs = {})
142
+ # %constant_pad_nd : Tensor "b8[8, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {})
143
+ # %view_10 : Tensor "b8[8, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, 16, 128, %floordiv_1, 128]), kwargs = {})
144
+ # %permute : Tensor "b8[8, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {})
145
+ # %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {})
146
+ # %gt : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
147
+ # %lt_3 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {})
148
+ # %bitwise_and_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {})
149
+ # %convert_element_type : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {})
150
+ # %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {})
151
+ # %eq_24 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {})
152
+ # %convert_element_type_1 : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {})
153
+ # %convert_element_type_5 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {})
154
+ # return %sum_1,%convert_element_type_2,%convert_element_type_5
155
# Registers the Triton source of a fused mask-counting reduction with
# async_compile. For each output index the kernel loops over a fixed reduction
# length of 16384 (r0_numel is overwritten inside the kernel body), evaluates a
# boolean predicate built from the decomposed indices and the value loaded from
# in_ptr0, and accumulates it as an int64 count. After the loop it stores two
# int32 flags per output index:
#   out_ptr1 <- 1 when 0 < count < 16384, otherwise 0
#   out_ptr2 <- 1 when count == 16384, otherwise 0
# NOTE(review): going by the generated graph comments (partial_blocks /
# full_blocks), these look like partial-block and full-block indicators of a
# 128x128-tiled mask -- confirm against the originating model code.
# The kernel text is a runtime string literal, so it is left untouched.
+ triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', '''
156
+ import triton
157
+ import triton.language as tl
158
+
159
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
160
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
161
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
162
+ triton_helpers.set_driver_to_gpu()
163
+
164
+ @triton_heuristics.reduction(
165
+ size_hints={'x': 4096, 'r0_': 16384},
166
+ reduction_hint=ReductionHint.INNER,
167
+ filename=__file__,
168
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
169
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
170
+ )
171
+ @triton.jit
172
+ def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
173
+ r0_numel = 16384
174
+ rnumel = r0_numel
175
+ RBLOCK: tl.constexpr = R0_BLOCK
176
+ xoffset = tl.program_id(0) * XBLOCK
177
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
178
+ xmask = xindex < xnumel
179
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
180
+ rbase = r0_base
181
+ x0 = (xindex % ks0)
182
+ x1 = ((xindex // ks0) % 16)
183
+ x2 = xindex // ks2
184
+ _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
185
+ x5 = xindex
186
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
187
+ r0_index = r0_offset + r0_base
188
+ r0_mask = r0_index < r0_numel
189
+ roffset = r0_offset
190
+ rindex = r0_index
191
+ r0_3 = (r0_index % 128)
192
+ r0_4 = r0_index // 128
193
+ tmp0 = r0_3 + 128*x0
194
+ tmp1 = ks1
195
+ tmp2 = tmp0 < tmp1
196
+ tmp3 = r0_4 + 128*x1
197
+ tmp4 = r0_3 + 128*x0
198
+ tmp5 = tmp3 >= tmp4
199
+ tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
200
+ tmp7 = tmp4 < tmp6
201
+ tmp8 = tmp3 < tmp6
202
+ tmp9 = tmp7 & tmp8
203
+ tmp10 = tmp5 & tmp9
204
+ tmp11 = tl.full([1, 1], False, tl.int1)
205
+ tmp12 = tmp11 | tmp10
206
+ tmp13 = tl.full([1, 1], 2048, tl.int64)
207
+ tmp14 = tmp4 >= tmp13
208
+ tmp15 = ((r0_3 + 128*x0) % 2048)
209
+ tmp16 = tmp15 < tmp6
210
+ tmp17 = tmp14 & tmp16
211
+ tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
212
+ tmp19 = (tmp18 % tmp13)
213
+ tmp20 = tl.full([1, 1], 0, tl.int32)
214
+ tmp21 = tmp19 != tmp20
215
+ tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
216
+ tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
217
+ tmp24 = tmp22 != tmp23
218
+ tmp25 = tmp21 & tmp24
219
+ tmp26 = tmp19 + tmp13
220
+ tmp27 = tl.where(tmp25, tmp26, tmp19)
221
+ tmp28 = tl.full([1, 1], 0, tl.int64)
222
+ tmp29 = tmp27 == tmp28
223
+ tmp30 = tmp17 & tmp29
224
+ tmp31 = tmp12 | tmp30
225
+ tmp32 = tl.full(tmp31.shape, False, tmp31.dtype)
226
+ tmp33 = tl.where(tmp2, tmp31, tmp32)
227
+ tmp34 = tmp33.to(tl.int64)
228
+ tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK])
229
+ tmp37 = _tmp36 + tmp35
230
+ _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36)
231
+ tmp36 = tl.sum(_tmp36, 1)[:, None]
232
+ tmp38 = tl.full([1, 1], 0, tl.int64)
233
+ tmp39 = tmp36 > tmp38
234
+ tmp40 = tl.full([1, 1], 16384, tl.int64)
235
+ tmp41 = tmp36 < tmp40
236
+ tmp42 = tmp39 & tmp41
237
+ tmp43 = tmp42.to(tl.int8)
238
+ tmp44 = tmp43.to(tl.int32)
239
+ tmp45 = tmp36 == tmp40
240
+ tmp46 = tmp45.to(tl.int8)
241
+ tmp47 = tmp46.to(tl.int32)
242
+ tl.store(out_ptr1 + (x5), tmp44, xmask)
243
+ tl.store(out_ptr2 + (x5), tmp47, xmask)
244
+ ''', device_str='cuda')
245
+
246
+
247
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rf/crf75ojmgx3s35d4vq6bm6ahr7jskra4xhldkhobbo2elpsvqhja.py
248
+ # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
249
+ # Source node to ATen node mapping:
250
+ # arange_4 => iota_4
251
+ # child_3 => convert_element_type_3
252
+ # child_4 => convert_element_type_4
253
+ # col_range => iota_5
254
+ # dense_mask_2 => full_default_1
255
+ # index_mask => lt_4
256
+ # num_blocks_in_row => sum_2
257
+ # row_indices => unsqueeze
258
+ # setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6
259
+ # unsqueeze_1 => unsqueeze_1
260
+ # valid_indices => scalar_tensor, where
261
+ # Graph fragment:
262
+ # %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=convert_element_type_2]
263
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_2]
264
+ # %getitem_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 128*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=getitem_1]
265
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_3]
266
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=convert_element_type_4]
267
+ # %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:1" = PlaceHolder[target=index_put]
268
+ # %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
269
+ # %iota_7 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
270
+ # %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {})
271
+ # %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {})
272
+ # %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {})
273
+ # %iota_6 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
274
+ # %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {})
275
+ # %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {})
276
+ # %iota_4 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False})
277
+ # %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {})
278
+ # %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False})
279
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {})
280
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {})
281
+ # %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {})
282
+ # %lt_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {})
283
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {})
284
+ # %scalar_tensor : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1})
285
+ # %where : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {})
286
+ # %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
287
+ # %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {})
288
+ # return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13
289
# Registers the Triton source of a fused row-sum + scatter kernel with
# async_compile (xnumel is fixed to 128 inside the kernel body).
#   Pass 1: sums in_ptr0 along the reduction axis in int64 and stores the
#           int32 row total to out_ptr1.
#   Pass 2: casts in_ptr1 to int32 and stores it to out_ptr2. For each column
#           it selects that value when column < row total, otherwise ks0;
#           adds (1 + ks0) to negative results, device-asserts the index is in
#           bounds, and stores 1 to out_ptr3 at that position within the row
#           (out_ptr3 is listed in mutated_arg_names).
# The kernel text is a runtime string literal, so it is left untouched.
+ triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', '''
290
+ import triton
291
+ import triton.language as tl
292
+
293
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
294
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
295
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
296
+ triton_helpers.set_driver_to_gpu()
297
+
298
+ @triton_heuristics.reduction(
299
+ size_hints={'x': 128, 'r0_': 32},
300
+ reduction_hint=ReductionHint.INNER,
301
+ filename=__file__,
302
+ triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
303
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
304
+ )
305
+ @triton.jit
306
+ def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
307
+ xnumel = 128
308
+ rnumel = r0_numel
309
+ RBLOCK: tl.constexpr = R0_BLOCK
310
+ xoffset = tl.program_id(0) * XBLOCK
311
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
312
+ xmask = xindex < xnumel
313
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
314
+ rbase = r0_base
315
+ x0 = xindex
316
+ _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
317
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
318
+ r0_index = r0_offset + r0_base
319
+ r0_mask = r0_index < r0_numel
320
+ roffset = r0_offset
321
+ rindex = r0_index
322
+ r0_1 = r0_index
323
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
324
+ tmp1 = tmp0.to(tl.int64)
325
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
326
+ tmp4 = _tmp3 + tmp2
327
+ _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
328
+ tmp3 = tl.sum(_tmp3, 1)[:, None]
329
+ tmp5 = tmp3.to(tl.int32)
330
+ tl.store(out_ptr1 + (x0), tmp5, xmask)
331
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
332
+ r0_index = r0_offset + r0_base
333
+ r0_mask = r0_index < r0_numel
334
+ roffset = r0_offset
335
+ rindex = r0_index
336
+ r0_1 = r0_index
337
+ tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
338
+ tmp7 = tmp6.to(tl.int32)
339
+ tmp8 = r0_1
340
+ tmp9 = tmp8 < tmp5
341
+ tmp10 = ks0
342
+ tmp11 = tl.where(tmp9, tmp7, tmp10)
343
+ tmp12 = 1 + ks0
344
+ tmp13 = tmp11 + tmp12
345
+ tmp14 = tmp11 < 0
346
+ tmp15 = tl.where(tmp14, tmp13, tmp11)
347
+ tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))")
348
+ tmp17 = tl.full([1, 1], 1, tl.int32)
349
+ tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask)
350
+ tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask)
351
+ ''', device_str='cuda')
352
+
353
+
354
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gn/cgnsrigp6qu2lbqq76g27kshvt2bzkyjnupza5ds7znhjxrnwhif.py
355
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
356
+ # Source node to ATen node mapping:
357
+ # batched_outputs_3 => clone_4, slice_2
358
+ # col_indices_2 => sort_2
359
+ # num_blocks_in_row_2 => sum_4
360
+ # q_indices => clone_6, convert_element_type_9
361
+ # q_num_blocks => convert_element_type_8
362
+ # transpose => permute_1
363
+ # Graph fragment:
364
+ # %buf13 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:1" = PlaceHolder[target=buf13]
365
+ # %buf15 : Tensor "i16[8, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), 16, 1]cuda:1" = PlaceHolder[target=buf15]
366
+ # %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=sum_4]
367
+ # %slice_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {})
368
+ # %clone_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
369
+ # %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
370
+ # %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True})
371
+ # %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
372
+ # %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
373
+ # %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
374
+ # %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
375
+ # return %buf15,%sum_4,%clone_6,%convert_element_type_8
376
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', '''
377
+ import triton
378
+ import triton.language as tl
379
+
380
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
381
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
382
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
383
+ triton_helpers.set_driver_to_gpu()
384
+
385
+ @triton_heuristics.persistent_reduction(
386
+ size_hints={'x': 256, 'r0_': 16},
387
+ reduction_hint=ReductionHint.DEFAULT,
388
+ filename=__file__,
389
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
390
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
391
+ )
392
+ @triton.jit
393
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
394
+ r0_numel = 16
395
+ R0_BLOCK: tl.constexpr = 16
396
+ rnumel = r0_numel
397
+ RBLOCK: tl.constexpr = R0_BLOCK
398
+ xoffset = tl.program_id(0) * XBLOCK
399
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
400
+ xmask = xindex < xnumel
401
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
402
+ r0_offset = 0
403
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
404
+ roffset = r0_offset
405
+ rindex = r0_index
406
+ r0_2 = r0_index
407
+ x0 = (xindex % ks0)
408
+ x1 = xindex // ks0
409
+ x3 = xindex
410
+ tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0)
411
+ tmp1 = r0_2
412
+ tmp2 = tmp1.to(tl.int16)
413
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
414
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
415
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
416
+ tmp7 = tmp0.to(tl.int64)
417
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
418
+ tmp10 = tl.where(xmask, tmp8, 0)
419
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
420
+ tmp12 = tmp6.to(tl.int64)
421
+ tmp13 = tmp12.to(tl.int32)
422
+ tmp14 = tmp11.to(tl.int32)
423
+ tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask)
424
+ tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask)
425
+ ''', device_str='cuda')
426
+
427
+
428
+ async_compile.wait(globals())
429
+ del async_compile
430
+
431
+ class Runner:
432
+ def __init__(self, partitions):
433
+ self.partitions = partitions
434
+
435
+ def recursively_apply_fns(self, fns):
436
+ new_callables = []
437
+ for fn, c in zip(fns, self.partitions):
438
+ new_callables.append(fn(c))
439
+ self.partitions = new_callables
440
+
441
+ def call(self, args):
442
+ arg0_1, arg1_1 = args
443
+ args.clear()
444
+ s37 = arg0_1
445
+ assert_size_stride(arg1_1, (8, ), (1, ))
446
+ with torch.cuda._DeviceGuard(1):
447
+ torch.cuda.set_device(1)
448
+ buf12 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32)
449
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
450
+ triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128)
451
+ stream1 = get_raw_stream(1)
452
+ triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream1)
453
+ buf19 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32)
454
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
455
+ triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128)
456
+ stream1 = get_raw_stream(1)
457
+ triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream1)
458
+ ps0 = (127 + s37) // 128
459
+ ps1 = 16*((127 + s37) // 128)
460
+ buf1 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32)
461
+ buf5 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32)
462
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy]
463
+ triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 128*((127 + s37) // 128)
464
+ stream1 = get_raw_stream(1)
465
+ triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream1)
466
+ del arg1_1
467
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort]
468
+ buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True)
469
+ buf4 = buf2[1]
470
+ assert_size_stride(buf4, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable')
471
+ assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable')
472
+ del buf2
473
+ buf10 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
474
+ buf11 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
475
+ # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
476
+ triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128
477
+ stream1 = get_raw_stream(1)
478
+ triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream1)
479
+ del buf1
480
+ del buf4
481
+ buf26 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32)
482
+ buf28 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
483
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
484
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128)
485
+ stream1 = get_raw_stream(1)
486
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream1)
487
+ del buf12
488
+ # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort]
489
+ buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True)
490
+ buf8 = buf6[1]
491
+ assert_size_stride(buf8, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable')
492
+ assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable')
493
+ del buf6
494
+ buf17 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
495
+ buf18 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
496
+ # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
497
+ triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128
498
+ stream1 = get_raw_stream(1)
499
+ triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream1)
500
+ del buf5
501
+ del buf8
502
+ buf23 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32)
503
+ buf25 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
504
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
505
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128)
506
+ stream1 = get_raw_stream(1)
507
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream1)
508
+ del buf19
509
+ return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, )
510
+
511
+ runner = Runner(partitions=[])
512
+ call = runner.call
513
+ recursively_apply_fns = runner.recursively_apply_fns
514
+
515
+
516
+ def benchmark_compiled_module(times=10, repeat=10):
517
+ from torch._dynamo.testing import rand_strided
518
+ from torch._inductor.utils import print_performance
519
+ arg0_1 = 4096
520
+ arg1_1 = rand_strided((8, ), (1, ), device='cuda:1', dtype=torch.int64)
521
+ fn = lambda: call([arg0_1, arg1_1])
522
+ return print_performance(fn, times=times, repeat=repeat)
523
+
524
+
525
+ if __name__ == "__main__":
526
+ from torch._inductor.wrapper_benchmark import compiled_module_main
527
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/47/c477sca7o5mbhj2pknepbw5b3rzush4uzefidcyxm6ysescvabgf.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 4096
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = (xindex % 2048)
29
+ x1 = xindex // 2048
30
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
31
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
32
+ x3 = xindex
33
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_2 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0)
40
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
41
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
42
+ _tmp2, _tmp2_index, tmp1, rindex
43
+ )
44
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
45
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
46
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
47
+ tmp2 = tmp2_idx[:, None]
48
+ tl.store(out_ptr0 + (x3), tmp2, None)
SpecForge-ext/cache/compiled_kernels/47/c47ib4wwcplmaudqv7246conlxcjjylxi5ahlxye6ebv6onrgoxg.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
34
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
35
+ tmp3 = tl.where(xmask, tmp1, 0)
36
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
37
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 4096
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
30
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
31
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
32
+ r0_index = r0_offset + r0_base
33
+ r0_mask = r0_index < r0_numel
34
+ roffset = r0_offset
35
+ rindex = r0_index
36
+ r0_1 = r0_index
37
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
38
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
40
+ _tmp2, _tmp2_index, tmp1, rindex
41
+ )
42
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
43
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
44
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
45
+ tmp2 = tmp2_idx[:, None]
46
+ tl.store(out_ptr0 + (x0), tmp2, None)
SpecForge-ext/cache/compiled_kernels/4d/71d6529f8e555bde19385922da8b4def675c0510b5b664b1bf4faebe9f8928eb.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "B46RWD5PEMKEQR7EBR6IG3BGTK4P7CWBVNOODNZQX5NAVXXVIH2A"}
SpecForge-ext/cache/compiled_kernels/4d/c4dlhtn4sdcx4l7dqawohb4u6hvu3xnzxvwazc5qrp6haqgnm3ev.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 2048},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = xindex
23
+ tmp0 = tl.full([1], 0, tl.int32)
24
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/4d/c4dw6ykxgbwk3glacutxkpzwhapvr5oszjet3n4i4q3snumjzm3x.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
# Inductor-generated Triton template kernel for an attention backward pass.
# One launch covers two kinds of programs along grid axis 0 (see the `pid`
# test below): programs with pid >= NUM_KV_BLOCKS accumulate dQ, all others
# accumulate dK and dV for one KV tile. dQ is stored through `arg_DQ`, dV
# through `arg_DV`, and dK through `out_ptr0`.
# Pointer element types are listed in `triton_meta['signature']` below
# (bf16 for Q/K/V/DO/DQ/DV/out_ptr0, fp32 for LSE/DELTA, i32 for the block
# tables, i64 for in_ptr16). `in_ptr16` is not read in this function; it is
# only forwarded to the inner helpers.
# NOTE(review): this text is a diff rendering (gutter numbers and '+' markers)
# with leading indentation stripped, so block nesting cannot be read from the
# line layout - consult the original generated file before editing any logic.
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
# Compile-time configuration; the values mirror 'config_args' in inductor_meta.
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
101
+
# Problem sizes are baked in as literals for this specialization.
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = 2048
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
# Grid axis 0 is split in two: program ids below NUM_KV_BLOCKS take the dK/dV
# path (the `else` further down); the remaining ids take the dQ path.
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
# off_pid encodes two indices: off_pid // NUM_Q_BLOCKS picks the query head
# inside this KV head's group, off_pid % NUM_Q_BLOCKS picks the Q tile.
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 256
149
+ stride_kv_idx_m = 16
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
# -inf logsumexp entries are replaced by 0.0 before lse is used by the inner
# helper (presumably to avoid inf - inf in the exp2 argument - confirm).
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 16
245
+ stride_q_idx_h = 256
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
# The same dk/dv accumulators are passed to and returned from the inner helper
# for each of the GQA_SHARED_HEADS query heads that map to this KV head.
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
# Flat element offset into out_ptr0. The literal strides 128, 262144 and
# 2097152 equal stride_kn, stride_kh and stride_kz defined above, and the
# batch term uses off_zq (the q batch index), not off_zkv.
344
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
# Inner dQ loop for one Q tile: walks the KV tiles reachable through
# `kv_indices`, calls bwd_dq_block_mn for each, and returns the accumulated dq.
#   K, V                  - pointers already offset by the caller for batch/kv-head
#   dq, q, do, Di, lse    - values for the current Q tile, prepared by the caller
#   offs_m2, offs_n2      - row (Q) and column (KV) offsets of the current tile
#   kv_indices            - pointer into a KV block-index table
#   sparse_kv_num_blocks  - number of sparse KV blocks to visit
#   IS_FULL_BLOCKS        - forwarded to bwd_dq_block_mn, where it decides
#                           whether the element-wise mask is evaluated
# The leading arg_*/in_ptr16/out_ptr0 parameters are forwarded unchanged.
# NOTE(review): get_offset_for_next_block is defined outside this view; it
# presumably returns how far to advance to reach the next KV tile - confirm.
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
# Same compile-time configuration block as in the other functions of this file.
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = True
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = 2048
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
# Transposed tile pointers: the head-dim offsets index rows ([:, None]) and
# the KV offsets index columns ([None, :]).
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
# Iteration count: BLOCK_N2-sized steps across the listed sparse blocks,
# capped at the number of BLOCK_N2 tiles in KV_LEN (and at least 1 in the cap).
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
# K/V tile pointers and the KV offsets are advanced by the same `offset`, so
# the positions used for masking stay in step with the data that is loaded.
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
# Adds the contribution of one (Q tile, KV tile) pair to dq and returns it.
# Steps visible below: load K^T, qk = q @ K^T (scaled by SM_SCALE unless
# PRESCALE_QK), optional element-wise mask when IS_FULL_BLOCKS is false,
# p = exp2(scores * RCP_LN2 - lse), dp = do @ V^T, ds = p * (dp - Di),
# and finally dq += ds @ K.
# The leading arg_*/out_ptr0 parameters are not used directly here; in_ptr16
# is read once, indexed by off_z, inside the mask computation.
# NOTE(review): load_checked_2d and get_bounded_indices are defined outside
# this view; their exact bounds-handling semantics should be confirmed there.
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
# Same compile-time configuration block as in the other functions of this file.
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = True
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
# The score modification is the identity in this specialization.
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
# Generated mask chain. With L = in_ptr16[off_z] (an int64 per-batch bound,
# presumably a valid length - confirm against the caller), the result is
#   (m >= n and n < L and m < L)
#   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0)
# where each "mod" is a floor-style remainder: the signbit/where sequences
# below add 2048 back whenever `%` yields a non-zero result of opposite sign.
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
# Scores are multiplied by RCP_LN2 (1.44269504 at the call site) so that the
# base-2 exponential can be used here.
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
# scatter_mask is computed but not referenced again within this function.
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
# Inner dK/dV loop for one KV tile: walks the Q tiles reachable through
# `q_indices`, calls bwd_dkdv_block_mn for each, and returns the accumulated
# (dk, dv) pair.
#   Q, DO, DELTA, LSE    - pointers already offset by the caller for one query head
#   dk, dv, k, v         - accumulators and the KV tile loaded by the caller
#   offs_n1, offs_m1     - KV and Q offsets of the current tile
#   q_indices            - pointer into a Q block-index table
#   sparse_q_num_blocks  - number of sparse Q blocks to visit
#   IS_FULL_BLOCKS       - forwarded unchanged to bwd_dkdv_block_mn
# The leading arg_*/in_ptr16/out_ptr0 parameters are forwarded unchanged.
# NOTE(review): get_offset_for_next_block is defined outside this view; it
# presumably returns how far to advance to reach the next Q tile - confirm.
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
# Same compile-time configuration block as in the other functions of this file.
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = True
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = 2048
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
# Q is addressed transposed (head-dim offsets as rows, Q offsets as columns);
# DO is addressed untransposed.
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
# Q/DO tile pointers and the Q offsets are advanced by the same `offset`, so
# the positions used for masking stay in step with the data that is loaded.
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV,
    arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX,
    arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX,
    in_ptr16, out_ptr0,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Adds the contribution of one (KV tile x Q tile) pair to the running
    # dK / dV accumulators and returns the updated (dk, dv) pair.
    #
    # k, v       : K / V tiles supplied by the caller (used directly in tl.dot).
    # qT_ptrs    : pointers to the transposed Q tile for the current Q rows.
    # do_ptrs    : pointers to the dO tile for the current Q rows.
    # LSE, DELTA : per-row base pointers, indexed with offs_m1.
    # IS_FULL_BLOCKS : when true the element-wise mask below is skipped.

    # Kernel configuration baked in by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Q is read transposed, hence the swapped (head-dim, row) argument order.
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # LSE is fetched ahead of the QK product to reduce pipeline stalls.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Rows whose logsumexp is -inf are remapped to 0.0.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT = qkT * SM_SCALE

    # ------------------------- score modification -------------------------
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The outer loop already bounds-checks, but while walking the M dimension
    # the N indices of programs straddling KV_LEN can still run past the end.
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    # This kernel has no score_mod: the scores pass through unchanged.
    pre_mod_scores = qkT
    post_mod_scores = qkT

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask_mod.
        # NOTE(review): in_ptr16 appears to hold one int64 bound per batch
        # entry (indexed by off_z) -- confirm against the originating mask_mod.
        always_false = tl.full([1], False, tl.int1)
        q_idx = (m)
        kv_idx = (n)
        q_not_before_kv = q_idx >= kv_idx
        kv_idx_i64 = kv_idx.to(tl.int64)
        batch_idx = (off_z)
        bound = tl.load(in_ptr16 + batch_idx)
        kv_below_bound = kv_idx_i64 < bound
        q_idx_i64 = q_idx.to(tl.int64)
        q_below_bound = q_idx_i64 < bound
        both_below_bound = kv_below_bound & q_below_bound
        causal_term = q_not_before_kv & both_below_bound
        first_term = always_false | causal_term

        modulus = tl.full([1], 2048, tl.int32)
        kv_at_or_past_modulus = kv_idx >= modulus
        # Python-style (sign-of-divisor) modulo of kv_idx by the modulus.
        kv_rem = (kv_idx % modulus)
        zero = tl.full([1], 0, tl.int32)
        kv_rem_nonzero = kv_rem != zero
        kv_rem_negative = (libdevice.signbit(kv_rem) != 0) if (kv_rem).dtype is tl.float32 else kv_rem < 0
        modulus_negative = (libdevice.signbit(modulus) != 0) if (modulus).dtype is tl.float32 else modulus < 0
        kv_signs_differ = kv_rem_negative != modulus_negative
        kv_needs_wrap = kv_rem_nonzero & kv_signs_differ
        kv_rem_wrapped = kv_rem + modulus
        kv_mod = tl.where(kv_needs_wrap, kv_rem_wrapped, kv_rem)
        kv_mod_i64 = kv_mod.to(tl.int64)
        kv_mod_below_bound = kv_mod_i64 < bound
        kv_term = kv_at_or_past_modulus & kv_mod_below_bound

        # Same Python-style modulo, applied to (kv_idx - q_idx).
        kv_minus_q = kv_idx - q_idx
        diff_rem = (kv_minus_q % modulus)
        diff_rem_nonzero = diff_rem != zero
        diff_rem_negative = (libdevice.signbit(diff_rem) != 0) if (diff_rem).dtype is tl.float32 else diff_rem < 0
        diff_signs_differ = diff_rem_negative != modulus_negative
        diff_needs_wrap = diff_rem_nonzero & diff_signs_differ
        diff_rem_wrapped = diff_rem + modulus
        diff_mod = tl.where(diff_needs_wrap, diff_rem_wrapped, diff_rem)
        diff_is_multiple = diff_mod == zero
        second_term = kv_term & diff_is_multiple

        combined = first_term | second_term
        mask_mod_output = combined

        # (grads) masked-out positions get -inf so their probability is zero.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # -----------------------------------------------------------------------
    if not PRESCALE_QK:
        post_mod_scores = post_mod_scores * RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # dV accumulation.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # dP and dS (both transposed).
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ------------------------- joint modification --------------------------
    # No score_mod gradient to apply: dS passes through unchanged.
    grad_scores = dsT

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # --------------------- other buffer gradient writes --------------------
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # -----------------------------------------------------------------------
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) zero the gradient wherever the mask rejected the position.
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # -----------------------------------------------------------------------
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
778
+
779
+ # Utility triton funcs
780
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far to advance the tile offsets after iteration `loop_iter`.
    #
    # Within a sparse block the step is simply BLOCK.  On the last tile of a
    # sparse block the step instead lands on the first tile of the next block
    # listed in `col_indices`.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    sparse_slot = loop_iter // SPARSE_BLOCK_MULTIPLE
    this_block = tl.load(col_indices + sparse_slot, eviction_policy="evict_last")
    # The mask keeps the look-ahead read from running past the index list.
    following_block = tl.load(col_indices + sparse_slot + 1, eviction_policy="evict_last", mask=sparse_slot + 1 < total_blocks)
    at_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance between the two sparse blocks, minus the tiles already walked
    # inside the current one.
    block_jump = (following_block - this_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Arithmetic select between the jump and the ordinary step.
    step = block_jump * at_block_end + (1 - at_block_end) * BLOCK
    return step
795
+
796
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wraps `indices` modulo `max_len`; with no bound given they are returned
    # untouched.  The None test is resolved at compile time.
    if max_len is None:
        bounded = indices
    else:
        bounded = indices % max_len
    return bounded
799
+
800
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Loads through a block pointer, boundary-checking (with zero padding)
    # only the dimensions that need it:
    #   dim 0 when IS_DIVISIBLE is false, dim 1 when SAFE_HEAD_DIM is false.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Loads a 2-D tile, masking whichever axes are not known to be divisible;
    # masked-off elements read as 0.0.
    #
    # When both strides are given, `ptr` is treated as a base pointer and the
    # tile addresses are built here.  Otherwise `ptr` is used as passed in
    # (presumably already a tile of pointers -- confirm at the call sites).
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Neither axis needs a bounds check.
            return tl.load(ptr)
        else:
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
SpecForge-ext/cache/compiled_kernels/57/c57svaeo74ael4oxqveudfvhx4xfmu3ikrmvljcnixb4kiqagrzn.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
@triton.jit
def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # Row-wise argmax: for each of the `xnumel` output positions, scans 32000
    # input values in chunks of R0_BLOCK and stores the index of the largest.
    r0_numel = 32000  # the reduction length is fixed; the argument is overridden
    # Aliases emitted by the code generator; retained as-is.
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    # Split the flat output index into (inner, outer) parts by ks0.
    inner_x = (xindex % ks0)
    outer_x = xindex // ks0
    # Running maximum and its index, one slot per (row, lane).
    best_val = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    best_idx = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    out_index = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        active = r0_mask & xmask
        chunk = tl.load(in_ptr0 + (r0_2 + 32000*inner_x + ks1*outer_x), active, eviction_policy='evict_first', other=0.0)
        chunk_b = tl.broadcast_to(chunk, [XBLOCK, R0_BLOCK])
        cand_val, cand_idx = triton_helpers.maximum_with_index(
            best_val, best_idx, chunk_b, rindex
        )
        # Only lanes that read real data may update the running state.
        best_val = tl.where(active, cand_val, best_val)
        best_idx = tl.where(active, cand_idx, best_idx)
    # Collapse the per-lane candidates along the reduction axis.
    row_max, row_argmax = triton_helpers.max_with_index(best_val, best_idx, 1)
    result = row_argmax[:, None]
    tl.store(out_ptr0 + (out_index), result, xmask)
SpecForge-ext/cache/compiled_kernels/5b/c5bepmksmwsj5jf67nbhozcficyzxyuodtllwlc5wms64ubwgqh6.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
@triton.jit
def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
    # Kernel configuration baked in by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Notation used in this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (always stored in fp32, whatever the input dtype)
    # DELTA: precomputed sum(OUT*DO, axis=-1)
    # DO: derivative of Output, DQ: derivative of Query, DV: derivative of Value
    # DK: derivative of Key; it is written through out_ptr0 (the store_output
    #     path) because of limitations in inductor codegen
    # M: number of queries, N: number of keys/values
    # QK_HEAD_DIM: dimension of the query and key embeddings
    # V_HEAD_DIM: dimension of the value embeddings
    # z: batch, h: head, m: query or key/value position, d: head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # (Tunable) performance options
    # BLOCK_M1: for DK & DV, the step over the Q sequence dim per iteration.
    # BLOCK_N1: for DK & DV, the thread-block size over the K/V sequence dim.
    # BLOCK_M2: for DQ, the thread-block size over the Q sequence dim.
    # BLOCK_N2: for DQ, the step over the K/V sequence dim per iteration.
    #
    # The block lists below live on the block-sparse mask grid, not on the
    # thread-block grid.
    # KV_NUM_BLKS / KV_IDX: count and indices of KV blocks (that may or may not
    #     require masking), looked up per sparse Q block row.
    # Q_NUM_BLKS / Q_IDX: count and indices of Q blocks (that may or may not
    #     require masking), looked up per sparse KV block row.
    # FULL_KV_NUM_BLKS / FULL_KV_IDX: same as KV_*, for fully unmasked KV blocks
    #     (no masking needed).
    # FULL_Q_NUM_BLKS / FULL_Q_IDX: same as Q_*, for fully unmasked Q blocks
    #     (no masking needed).
    #
    # Kernel options tied to particular score_mods or to a numerics/perf trade:
    # PRESCALE_QK: pre-scale QK by 1/sqrt(d) and the change of base. About 20%
    #     more numerical error, but slightly faster.

    # Input / output strides.
    stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1

    ZQ = 8
    HQ = 32
    HKV = 8
    Q_LEN = ks0
    ZKV = 8
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    # Axis 0 enumerates the KV tiles (dK/dV work) first, then the Q tiles
    # (dQ work); axes 1 and 2 select the batch entry and the kv head.
    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
    off_zkv = off_zq % ZKV # kv batch idx

    SPARSE_Z = 8
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # dv is first produced broadcast as [Bq, Hkv, KV_LEN, V_HEAD_DIM] and later
    # reduced to [Bkv, Hkv, KV_LEN, V_HEAD_DIM], so it is offset by the q batch.
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # Point K, V and DV at this program's batch entry / kv head.
    K = K + k_adj
    V = V + v_adj
    DV = DV + dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        # ============================ dQ programs ============================
        dq_pid = pid - NUM_KV_BLOCKS
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = dq_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        q_block_idx = dq_pid % NUM_Q_BLOCKS
        sparse_q_row = q_block_idx // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = ks2
        stride_kv_idx_h = ks3*ks4
        stride_kv_idx_m = ks4

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + sparse_q_row
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + sparse_q_row * stride_kv_idx_m  # noqa: B950

        # Q, DQ, DO, DELTA and LSE are offset by the query head.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        q_row_start = q_block_idx * BLOCK_M2
        offs_m2 = q_row_start + tl.arange(0, BLOCK_M2)

        # Q and dO are loaded once and reused across the whole inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ---- pass 1: blocks listed in KV_IDX (may require masking) ----
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ---- pass 2: blocks listed in FULL_KV_IDX (fully unmasked) ----
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Store dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq = dq * SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # ========================== dK / dV programs ==========================
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        sparse_kv_row = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = ks5
        stride_q_idx_h = ks6*ks7
        stride_q_idx_n = ks6


        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        kv_row_start = pid * BLOCK_N1
        offs_n1 = kv_row_start + tl.arange(0, BLOCK_N1)

        # K and V are loaded once and reused across the whole inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        # Every query head that shares this kv head adds into the same dk / dv.
        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Q, DQ, DO, DELTA and LSE are offset by the query head.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + sparse_kv_row
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + sparse_kv_row * stride_q_idx_n  # noqa: B950

            # ---- pass 1: blocks listed in Q_IDX (may require masking) ----
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )


            if HAS_FULL_BLOCKS:
                # ---- pass 2: blocks listed in FULL_Q_IDX (fully unmasked) ----
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Store dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        rows_n = offs_n1[:, None]
        cols_k = offs_k[None, :]
        cols_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(rows_n < KV_LEN) & (cols_v < V_HEAD_DIM))

        dk = dk * SM_SCALE

        if SAFE_HEAD_DIM:
            dk_mask = rows_n < KV_LEN
        else:
            dk_mask = (rows_n < KV_LEN) & (cols_k < QK_HEAD_DIM)

        # dk is first produced broadcast as [Bq, Hkv, KV_LEN, V_HEAD_DIM] and
        # later reduced to [Bkv, Hkv, KV_LEN, V_HEAD_DIM].
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        dk_index = cols_k + 128*rows_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
        tl.store(out_ptr0 + (tl.broadcast_to(dk_index, dk.shape)), dk, dk_mask)
346
+
347
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Walks the KV tiles selected by `kv_indices` for one Q tile, folding each
    # tile's contribution into `dq` via bwd_dq_block_mn, and returns `dq`.

    # Kernel configuration baked in by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    # The code below only works when BLOCK_M2 is a multiple of BLOCK_N2.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Transposed tiles: head dim along axis 0, KV position along axis 1.
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd

    # Trip count: the tiles named by the sparse block list, capped by the
    # number of tiles KV_LEN can hold (the cap itself is at least 1).
    kv_tile_cap = tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)
    num_kv_steps = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, kv_tile_cap)

    for kv_step in range(0, num_kv_steps):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Move on to the next KV tile.
        advance = get_offset_for_next_block(
            kv_step, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs = kT_ptrs + advance * stride_kn
        vT_ptrs = vT_ptrs + advance * stride_vn

        offs_n2 = offs_n2 + advance

    return dq
420
+
421
+
422
# One tile step of the dQ pass: recomputes the attention probabilities for a
# (query tile, key/value tile) pair, forms dS, and adds dS @ K to `dq`.
#
# Arguments, as they are used in this function:
#   dq                  accumulator; updated in place and returned.
#   q, do               operands of the two tl.dot calls against kT and vT.
#   kT_ptrs, vT_ptrs    pointers to the transposed K / V tiles to load.
#   Di, lse             per-row terms subtracted from dP and from the scaled scores.
#   Q_LEN, KV_LEN       bounds used for masking and index wrapping.
#   off_z               index used to read one value from in_ptr16.
#   offs_m2, offs_n2    row / column offsets of this tile.
#   offs_k, offs_v      head-dim offsets handed to load_checked_2d.
#   MATMUL_PRECISION    dtype `ds` is cast to before the final dot.
#   RCP_LN2             factor applied to the scores before exp2.
#   IS_FULL_BLOCKS      when true, the mask computation is skipped entirely.
#   ks8                 period used by the second clause of the mask.
# The remaining arguments are not referenced in this function.
#
# Returns the updated `dq`.
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Kernel configuration constants (all tl.constexpr); only a subset is
    # referenced in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843  # ~ 1/sqrt(128)
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: the (head-dim, sequence) argument order is reversed because K is transposed.
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk  # not read again in this function
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # The score modification is the identity here: scores pass through unchanged.
    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        # Columns at or beyond KV_LEN must not contribute.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Mask, with L = value loaded from in_ptr16 at off_z and P = ks8.
        # An (m, n) pair is kept when
        #     (m >= n and n < L and m < L)
        #  or (n >= P and floor_mod(n, P) < L and floor_mod(n - m, P) == 0).
        # tmp16..tmp24 and tmp29..tmp35 turn the raw `%` result into a
        # floor-style modulo: the divisor is added back when the remainder is
        # non-zero and its sign differs from the divisor's.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = ks8
        tmp15 = tmp3 >= tmp14
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        # Apply the mask for a partially masked block: rejected positions get -inf.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Rescale so that exp2 below reproduces a natural-base exponential.
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB: the (head-dim, sequence) argument order is reversed because V is transposed.
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The joint (gradient) modification is the identity as well.
    tmp39 = (ds)
    grad_scores = tmp39


    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # Computed but not read again in this function.
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
556
+
557
+
558
# Inner loop of the dK/dV pass: steps through the query blocks selected by
# `q_indices` and accumulates each block's contribution into `dk` and `dv`.
#
# Arguments, as they are used in this function:
#   arg_* / in_ptr16 / out_ptr0 / ks*  forwarded untouched to bwd_dkdv_block_mn;
#                                      ks0 and ks1 are additionally read as
#                                      Q_LEN and KV_LEN.
#   Q, DO                base pointers from which the Q^T and dO tile pointers
#                        are built.
#   DELTA, LSE           pointers forwarded to bwd_dkdv_block_mn.
#   dk, dv               running accumulators; reassigned on every iteration
#                        and returned.
#   k, v                 forwarded unchanged to bwd_dkdv_block_mn.
#   offs_n1, offs_m1     key/value and query offsets of the current tile;
#                        offs_m1 is advanced inside the loop.
#   stride_q*, stride_do* strides applied to the offsets when forming pointers.
#   q_indices, sparse_q_num_blocks
#                        sparse block list and its length, consumed by
#                        get_offset_for_next_block.
#   MATMUL_PRECISION, IS_FULL_BLOCKS   forwarded to bwd_dkdv_block_mn.
#
# Returns the tuple (dk, dv).
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Kernel configuration constants (all tl.constexpr); only a subset is
    # referenced in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843  # ~ 1/sqrt(128)
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Number of BLOCK_M1-wide steps contained in one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    # ~ 1/ln(2); passed down so scores can be exponentiated with exp2.
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head-dim rows, query columns); dO is not.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # The same offset moves both tile pointers and the logical query
        # offsets, so the bounds/mask tests in the next iteration stay in sync.
        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv
631
+
632
+
633
# One tile step of the dK/dV pass: recomputes the transposed attention
# probabilities for a (key/value tile, query tile) pair, adds P^T @ dO to
# `dv`, forms dS^T, and adds dS^T @ Q to `dk`.
#
# Arguments, as they are used in this function:
#   dk, dv              accumulators; updated in place and returned.
#   qT_ptrs, do_ptrs    pointers to the Q^T and dO tiles to load.
#   k, v                operands of the tl.dot calls against qT and dO^T.
#   DELTA, LSE          pointers indexed by offs_m1 to read per-query terms.
#   Q_LEN, KV_LEN       bounds used for masking and index wrapping.
#   off_z               index used to read one value from in_ptr16.
#   offs_n1, offs_m1    key/value and query offsets of this tile.
#   offs_k, offs_v      head-dim offsets handed to load_checked_2d.
#   MATMUL_PRECISION    dtype P^T and dS^T are cast to before the dots.
#   RCP_LN2             factor applied to the scores before exp2.
#   IS_FULL_BLOCKS      when true, the mask computation is skipped entirely.
#   ks8                 period used by the second clause of the mask.
# The remaining arguments are not referenced on the active code path.
#
# Returns the tuple (dk, dv).
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Kernel configuration constants (all tl.constexpr); only a subset is
    # referenced in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843  # ~ 1/sqrt(128)
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf entries with 0 (presumably so that a -inf score minus a
    # -inf LSE does not produce NaN in the exp2 below - TODO confirm).
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT  # not read again in this function
    # The score modification is the identity here: scores pass through unchanged.
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        # Query columns at or beyond Q_LEN must not contribute.
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Mask, with L = value loaded from in_ptr16 at off_z and P = ks8.
        # An (m, n) pair is kept when
        #     (m >= n and n < L and m < L)
        #  or (n >= P and floor_mod(n, P) < L and floor_mod(n - m, P) == 0).
        # tmp56..tmp64 and tmp69..tmp75 turn the raw `%` result into a
        # floor-style modulo: the divisor is added back when the remainder is
        # non-zero and its sign differs from the divisor's.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = ks8
        tmp55 = tmp43 >= tmp54
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # (grads) apply the mask for a partially masked block: rejected positions get -inf.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Rescale so that exp2 below reproduces a natural-base exponential.
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The joint (gradient) modification is the identity as well.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is True in this kernel, so this branch is compiled out.
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute dK.
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
778
+
779
+ # Utility triton funcs
780
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far the caller must advance its offsets/pointers to reach the
    # tile processed on the next loop iteration.
    #
    #   loop_iter              current iteration counter of the caller's loop
    #   col_indices            pointer to the list of selected sparse block ids
    #   total_blocks           number of valid entries behind col_indices
    #   SPARSE_BLOCK           size of one sparse block
    #   SPARSE_BLOCK_MULTIPLE  number of BLOCK-sized steps per sparse block
    #   BLOCK                  size of one step
    #   BLOCKS_ARE_CONTIGUOUS  when set, every step is simply BLOCK
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK

    # Which entry of col_indices the current step belongs to, and its successor.
    sparse_pos = loop_iter // SPARSE_BLOCK_MULTIPLE
    following_pos = sparse_pos + 1
    this_block = tl.load(col_indices + sparse_pos, eviction_policy="evict_last")
    # The successor is only read while it lies inside the list.
    following_block = tl.load(
        col_indices + following_pos,
        eviction_policy="evict_last",
        mask=following_pos < total_blocks,
    )

    # True on the last step inside the current sparse block.
    at_boundary = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0

    # Distance between the two sparse blocks, minus the steps already taken
    # inside the current one.
    block_gap = (following_block - this_block) * SPARSE_BLOCK
    already_walked = (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    jump = block_gap - already_walked

    # Arithmetic select: `jump` on a boundary step, otherwise a plain BLOCK.
    step = jump * at_boundary + (1 - at_boundary) * BLOCK
    return step
795
+
796
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wraps `indices` into [0, max_len) with a modulo when a bound is supplied;
    # without a bound the indices are returned untouched.
    if max_len is None:
        return indices
    else:
        return indices % max_len
799
+
800
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Loads the tile behind `block_ptr`, enabling zero-padded boundary checks
    # only on the dimensions that need them:
    #   dimension 0 is checked unless IS_DIVISIBLE is set,
    #   dimension 1 is checked unless SAFE_HEAD_DIM is set.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            # Neither dimension can run out of bounds.
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Loads a 2-D tile, masking out-of-range rows and/or columns with 0.0.
    #
    #   ptr               tile pointers, or a base pointer when strides are given
    #   offs_m, offs_n    row / column offsets of the tile
    #   stride_m/stride_n when both are provided, the tile pointers are derived
    #                     from `ptr` and the offsets; otherwise `ptr` is used as is
    #   IS_DIVISIBLE_M/N  when set, the corresponding dimension is not masked
    #   M_LEN, N_LEN      exclusive upper bounds for offs_m / offs_n
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Nothing to guard: plain load.
            return tl.load(ptr)
        else:
            # Only columns can overrun.
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            # Only rows can overrun.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            # Both dimensions need a guard.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
SpecForge-ext/cache/compiled_kernels/5b/c5bjpwebwz42wx5vrxljhccvdqdyttlxd2hpbtzw7ia3oq6c33ne.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 4096},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}}
)
@triton.jit
def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # Writes a single float to out_ptr2[0]:
    #     sum((in_ptr0 == in_ptr1) * in_ptr2) / max(sum(in_ptr3), 1e-06)
    # where both sums run over the 4096 reduction elements.
    xnumel = 1
    r0_numel = 4096
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base

    # First reduction: weighted count of positions where in_ptr0 equals in_ptr1.
    match_acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        lhs = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        rhs = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        weight = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        is_equal = lhs == rhs
        weighted_hit = is_equal.to(tl.int64) * weight
        weighted_hit_tile = tl.broadcast_to(weighted_hit, [XBLOCK, R0_BLOCK])
        match_next = match_acc + weighted_hit_tile
        # Lanes beyond the reduction length keep their previous value.
        match_acc = tl.where(r0_mask, match_next, match_acc)
    match_total = tl.sum(match_acc, 1)[:, None]

    # Second reduction: plain sum of in_ptr3.
    denom_acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        denom_val = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        denom_tile = tl.broadcast_to(denom_val, [XBLOCK, R0_BLOCK])
        denom_next = denom_acc + denom_tile
        denom_acc = tl.where(r0_mask, denom_next, denom_acc)
    denom_total = tl.sum(denom_acc, 1)[:, None]

    # Divide in float32, clamping the denominator from below.
    numerator = match_total.to(tl.float32)
    denominator = denom_total.to(tl.float32)
    floor_value = 1e-06
    safe_denominator = triton_helpers.maximum(denominator, floor_value)
    ratio = (numerator / safe_denominator)
    tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), ratio, None)
SpecForge-ext/cache/compiled_kernels/5g/c5g26r4ygcctmxuptx453t3kikkqukh73touvd4yxv7futs36kgf.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Pointwise kernel. For every flat output index x3, with x0 = x3 % ks0 and
# x1 = (x3 // ks0) % ks1, it stores the sum of three terms:
#   * x0 >= ks0 // 2 :  -(in_ptr0[x3 - ks0 // 2] * in_ptr2[row2 * ks0 + x0 - ks0 // 2])
#   * x0 <  ks0 // 2 :    in_ptr0[x3 + ks0 - ks0 // 2] * in_ptr2[row2 * ks0 + x0 + ks0 - ks0 // 2]
#   * always         :    in_ptr0[x3] * in_ptr3[row3 * ks0 + x0]
# where row2 / row3 are in_ptr1[x1] with negative values wrapped by ks2 / ks3.
# NOTE(review): the pattern looks like the backward of a rotate-half rotary
# embedding with gathered cos/sin tables - confirm against the traced graph.
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    # Decompose the flat index: x0 runs over a dimension of size ks0, x1 over
    # the next dimension of size ks1.
    x0 = (xindex % ks0)
    x3 = xindex
    x1 = ((xindex // ks0) % ks1)
    # Inputs of the third term, loaded up front.
    tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
    # ---- term 1: upper half of the ks0 dimension (x0 >= ks0 // 2) ----
    tmp0 = x0
    tmp1 = ks0 // 2
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
    # Wrap a negative row index by adding ks2, then bounds-check it.
    tmp5 = tl.broadcast_to(ks2, [XBLOCK])
    tmp6 = tmp4 + tmp5
    tmp7 = tmp4 < 0
    tmp8 = tl.where(tmp7, tmp6, tmp4)
    tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
    tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp3 * tmp10
    tmp12 = -tmp11
    # Zero outside the upper half.
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp2, tmp12, tmp13)
    tmp15 = 0.0
    tmp16 = tl.where(tmp2, tmp14, tmp15)
    # ---- term 2: lower half of the ks0 dimension (x0 < ks0 // 2) ----
    tmp17 = tmp0 < tmp1
    tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
    # Same negative-index wrap and bounds check as above.
    tmp20 = tl.broadcast_to(ks2, [XBLOCK])
    tmp21 = tmp19 + tmp20
    tmp22 = tmp19 < 0
    tmp23 = tl.where(tmp22, tmp21, tmp19)
    tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
    tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp26 = tmp18 * tmp25
    # Zero outside the lower half.
    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
    tmp28 = tl.where(tmp17, tmp26, tmp27)
    tmp29 = tl.where(tmp17, tmp28, tmp15)
    tmp30 = tmp16 + tmp29
    # ---- term 3: in_ptr0[x3] times the gathered in_ptr3 row ----
    tmp33 = ks3
    tmp34 = tmp32 + tmp33
    tmp35 = tmp32 < 0
    tmp36 = tl.where(tmp35, tmp34, tmp32)
    tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
    tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp39 = tmp31 * tmp38
    tmp40 = tmp30 + tmp39
    tl.store(out_ptr0 + (x3), tmp40, xmask)
SpecForge-ext/cache/compiled_kernels/5g/c5g4egnryommgtc4braxeh3xxhfypsb6ec3v4sx3gxmfoeho5bzl.py ADDED
@@ -0,0 +1,1083 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['13_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
# Short module-level aliases used by the rest of this generated wrapper.

# Operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Guard helpers from torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
# Strided allocators, one per device flavour.
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler handle; the kernels defined below are registered through
# async_compile.triton(...).
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/of/cofwz2ulo5xzqhau3cyhif5tweuyn7cqvg27usnkxh25zmnsmxqm.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[8, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf0]
44
+ # %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
45
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
46
+ # return %buf0,%buf1
47
+ triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
48
+ import triton
49
+ import triton.language as tl
50
+
51
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
52
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
53
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
54
+ triton_helpers.set_driver_to_gpu()
55
+
56
+ @triton_heuristics.reduction(
57
+ size_hints={'x': 524288, 'r0_': 128},
58
+ reduction_hint=ReductionHint.DEFAULT,
59
+ filename=__file__,
60
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
61
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
62
+ )
63
+ @triton.jit
64
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
65
+ r0_numel = 128
66
+ rnumel = r0_numel
67
+ RBLOCK: tl.constexpr = R0_BLOCK
68
+ xoffset = tl.program_id(0) * XBLOCK
69
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
70
+ xmask = xindex < xnumel
71
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
72
+ rbase = r0_base
73
+ x0 = (xindex % ks0)
74
+ x1 = ((xindex // ks0) % 32)
75
+ x2 = xindex // ks1
76
+ x5 = triton_helpers.div_floor_integer(xindex, ks0)
77
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
78
+ x4 = xindex
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_3 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp7 = 0.0
94
+ tmp8 = tmp6 - tmp7
95
+ tl.store(out_ptr1 + (x4), tmp8, xmask)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2]
104
+ # %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_4]
105
+ # %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_6]
106
+ # %getitem_1 : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=getitem_5]
111
+ # %primals_13 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13]
112
+ # %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9]
113
+ # %primals_22 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:0" = PlaceHolder[target=primals_22]
114
+ # %primals_25 : Tensor "i32[8, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:0" = PlaceHolder[target=primals_25]
115
+ # %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=primals_17]
116
+ # %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_20]
117
+ # %primals_27 : Tensor "i32[8, 1, s100][s100, s100, 1]cuda:0" = PlaceHolder[target=primals_27]
118
+ # %primals_30 : Tensor "i32[8, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:0" = PlaceHolder[target=primals_30]
119
+ # %primals_14 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=primals_14]
120
+ # %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = False
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, which is written via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each KV block.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each KV block.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each KV block.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each KV block.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
223
+
224
+ ZQ = 8
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = ks0
228
+ ZKV = 8
229
+ KV_LEN = ks1
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 8
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = ks2
270
+ stride_kv_idx_h = ks3*ks4
271
+ stride_kv_idx_m = ks4
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = ks5
367
+ stride_q_idx_h = ks6*ks7
368
+ stride_q_idx_n = ks6
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = False
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = ks0
507
+ KV_LEN = ks1
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = False
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB: reversed order since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
588
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = ks8
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulates dk and dv over a sequence of Q blocks.
    #
    # For each of `hi` iterations this calls bwd_dkdv_block_mn with the
    # current Q / DO pointers, then advances those pointers (and offs_m1)
    # by the offset returned from get_offset_for_next_block.
    # Returns the updated (dk, dv) accumulators.
    #
    # The leading arg_* / in_ptr16 / out_ptr0 / ks* parameters are forwarded
    # unchanged to bwd_dkdv_block_mn; only ks0 and ks1 are read here.

    # Compile-time configuration emitted by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Number of BLOCK_M1-sized tiles per sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    # Scores are multiplied by this before exp2 in bwd_dkdv_block_mn
    # (numerically 1/ln(2)).
    RCP_LN2: tl.constexpr = 1.44269504
    # Sequence lengths arrive as the runtime scalars ks0 / ks1.
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head dim along rows, offs_m1 along columns);
    # DO is addressed with offs_m1 along rows.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # The same row offset applies to Q, DO and the logical row indices.
        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv
753
+
754
+
755
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Processes one (KV block, Q block) tile and adds its contribution to the
    # dk and dv accumulators, which are returned.
    #
    # Steps: load Q^T and LSE, form the scaled scores k @ Q^T, mask them,
    # recompute the probabilities pT with exp2, accumulate dv += pT @ dO,
    # form dsT = pT * (v @ dO^T - Di), mask it and accumulate dk += dsT @ Q.
    #
    # When IS_FULL_BLOCKS is false, the inlined mask reads in_ptr16[off_z]
    # and ks8 (see the annotated tmp* chain below).

    # Compile-time configuration emitted by the code generator.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf log-sum-exp entries with 0 before they are subtracted below.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    # The score modification is the identity here: scores pass through unchanged.
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        # Columns past Q_LEN must not contribute.
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask. An element is kept when either clause holds:
        #   (a) m >= n, and both n and m are below the value loaded from
        #       in_ptr16[off_z];
        #   (b) n >= ks8, (n mod ks8) is below that same loaded value, and
        #       (n - m) mod ks8 == 0.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = ks8
        tmp55 = tmp43 >= tmp54
        # tmp56..tmp64: remainder of n by ks8, with the divisor added back when
        # the remainder is nonzero and its sign differs from the divisor's.
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        # tmp68..tmp75: the same sign-adjusted remainder, applied to (n - m).
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # Recompute the (transposed) probabilities from the scores and LSE.
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # Identity again: the score gradient passes through unchanged.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is the constant True in this kernel, so this branch is not taken.
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
900
+
901
+ # Utility triton funcs
902
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far (in elements along the iterated dimension) the caller
    # should advance after finishing iteration `loop_iter`.
    #
    # Each sparse block is walked in SPARSE_BLOCK_MULTIPLE steps of BLOCK.
    # Inside a sparse block the step is BLOCK; on the last step of a sparse
    # block the offset instead jumps to the start of the next block index
    # listed in `col_indices`.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    # The load of the next index is masked off when there is no next block.
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance between the two sparse blocks, minus the (MULTIPLE - 1) BLOCK
    # steps already taken inside the current one.
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Arithmetic select between the jump and the regular BLOCK step.
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
917
+
918
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wraps `indices` modulo `max_len`; when `max_len` is None the indices are
    # returned unchanged.
    return indices % max_len if max_len is not None else indices
921
+
922
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Loads from `block_ptr`, enabling zero-padded boundary checks only on the
    # dimensions that need them: dimension 0 when IS_DIVISIBLE is false,
    # dimension 1 when SAFE_HEAD_DIM is false.
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Masked 2D load. Rows are indexed by offs_m and columns by offs_n.
    # A dimension is bounds-checked against its *_LEN only when the matching
    # IS_DIVISIBLE_* flag is false; masked-out elements read as 0.0.
    #
    # If both strides are given, `ptr` is treated as a base pointer and the
    # 2D pointer grid is built here; otherwise `ptr` is used as passed in.

    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
958
+ ''', device_str='cuda')
959
+
960
+
961
# Wait on the kernel compilations submitted above (the module globals are
# passed in), then drop the compile handle, which is not referenced afterwards.
async_compile.wait(globals())
del async_compile
963
+
964
class Runner:
    """Holds the partition callables and the entry point of this compiled graph."""

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition with the result of applying its paired function.

        Functions and partitions are paired positionally; pairing stops at the
        shorter of the two sequences.
        """
        self.partitions = [fn(part) for fn, part in zip(fns, self.partitions)]

    def call(self, args):
        """Run the compiled graph on `args` and return its 30-element output tuple.

        `args` is consumed: it is unpacked and then cleared in place. The
        first 15 entries are the symbolic sizes, the remaining 15 are tensors.
        Only positions 1, 3 and 5 of the result hold tensors (buf3, buf5,
        buf4); every other position is None.
        """
        (
            s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6,
            primals_2, primals_4, primals_6, primals_9, primals_13, primals_14,
            primals_17, primals_20, primals_22, primals_25, primals_27, primals_30,
            getitem, getitem_1, tangents_1,
        ) = args
        args.clear()

        # Layouts shared by several inputs and by the output buffers.
        q_size = (8, 32, s37, 128)
        q_stride = (4096*s37, 128, 4096, 1)
        kv_size = (8, 8, s0, 128)
        kv_stride = (1024*s0, 128*s0, 128, 1)

        assert_size_stride(primals_2, q_size, q_stride)
        assert_size_stride(primals_4, kv_size, kv_stride)
        assert_size_stride(primals_6, kv_size, kv_stride)
        assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
        assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1))
        assert_size_stride(primals_14, (8, ), (1, ))
        assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1))
        assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
        assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1))
        assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
        assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1))
        assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1))
        assert_size_stride(getitem, q_size, q_stride)
        assert_size_stride(getitem_1, (8, 32, s37), (32*max(1, s37), max(1, s37), 1))
        assert_size_stride(tangents_1, q_size, (4096*max(1, s37), 128*max(1, s37), 128, 1))

        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            ps0 = 32*s37
            buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
            triton_red_fused_zeros_0_xnumel = 256*s37
            stream0 = get_raw_stream(0)
            triton_red_fused_zeros_0.run(
                getitem, tangents_1, buf1, s37, ps0,
                triton_red_fused_zeros_0_xnumel, 128,
                stream=stream0,
            )
            # The forward output is not used after the reduction above.
            del getitem

            buf3 = empty_strided_cuda(q_size, q_stride, torch.bfloat16)
            buf4 = empty_strided_cuda(kv_size, kv_stride, torch.bfloat16)
            buf5 = empty_strided_cuda(kv_size, kv_stride, torch.bfloat16)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
            stream0 = get_raw_stream(0)
            triton_tem_fused_zeros_1.run(
                primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1,
                buf3, buf4,
                primals_13, primals_9, primals_22, primals_25,
                primals_17, primals_20, primals_27, primals_30,
                primals_14, buf5,
                s37, s0, s99, s22, s72, s56, s53, s84, s75,
                4*((127 + s37) // 128) + ((127 + s0) // 128), 8, 8,
                stream=stream0,
            )
            # Drop local references to everything except the returned buffers.
            del buf1, getitem_1
            del primals_13, primals_14, primals_17, primals_2, primals_20
            del primals_22, primals_25, primals_27, primals_30
            del primals_4, primals_6, primals_9
            del tangents_1
        return (
            None, buf3, None, buf5, None, buf4,
            None, None, None, None, None, None,
            None, None, None, None, None, None,
            None, None, None, None, None, None,
            None, None, None, None, None, None,
        )
1038
+
1039
# Module-level entry points: a single Runner with no partitions, with its
# bound methods re-exported under the plain names `call` and
# `recursively_apply_fns`.
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
1042
+
1043
+
1044
def benchmark_compiled_module(times=10, repeat=10):
    """Time `call` on randomly generated inputs with fixed example sizes.

    The three sequence-length symbols are 2009 and the twelve remaining size
    symbols are 16. Returns whatever `print_performance` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def make(size, stride, dtype):
        return rand_strided(size, stride, device='cuda:0', dtype=dtype)

    def block_table():
        return make((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)

    def block_counts():
        return make((8, 1, 16), (16, 16, 1), torch.int32)

    q_size = (8, 32, 2009, 128)
    q_stride = (8228864, 128, 4096, 1)
    kv_size = (8, 8, 2009, 128)
    kv_stride = (2057216, 257152, 128, 1)

    # Tensors are created in the same order in which they are passed to `call`.
    tensors = [
        make(q_size, q_stride, torch.bfloat16),                       # primals_2
        make(kv_size, kv_stride, torch.bfloat16),                     # primals_4
        make(kv_size, kv_stride, torch.bfloat16),                     # primals_6
        block_table(),                                                # primals_9
        block_counts(),                                               # primals_13
        make((8, ), (1, ), torch.int64),                              # primals_14
        block_counts(),                                               # primals_17
        block_table(),                                                # primals_20
        block_counts(),                                               # primals_22
        block_table(),                                                # primals_25
        block_counts(),                                               # primals_27
        block_table(),                                                # primals_30
        make(q_size, q_stride, torch.bfloat16),                       # getitem
        make((8, 32, 2009), (64288, 2009, 1), torch.float32),         # getitem_1
        make(q_size, (8228864, 257152, 128, 1), torch.bfloat16),      # tangents_1
    ]

    # `call` clears the list it receives, so a fresh list is built per run.
    fn = lambda: call([2009] * 3 + [16] * 12 + tensors)
    return print_performance(fn, times=times, repeat=repeat)
1079
+
1080
+
1081
# When run as a script, hand the benchmark function to Inductor's
# compiled-module driver.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/5g/f7b122e9e44d29c2d695b9633c17bd6eee619900dd2325a29224a34ca8164da3.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"}
SpecForge-ext/cache/compiled_kernels/5j/c5jkwkcyjhbh7wkjfrwxf42i764iqjvwig2sk7me5dmfpiuqgldn.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['7_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
# Short module-level aliases for op namespaces and for the guard, allocation
# and reinterpretation helpers referenced by the generated wrapper code.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compile handle used by the async_compile.triton(...) kernel definitions below.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xg/cxgas2jt2e7frwpm3nd5h7wlzdo2fb2yonkvkfoarpyafiwosg23.py
38
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
39
+ # Source node to ATen node mapping:
40
+ # argmax => argmax
41
+ # Graph fragment:
42
+ # %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:3" = PlaceHolder[target=arg0_1]
43
+ # %argmax : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {})
44
+ # return %argmax
45
# Row-wise argmax reduction kernel.
# Inside the kernel xnumel is fixed to 16384 rows and r0_numel to 32000
# elements per row; rows are read contiguously (row stride 32000) from a bf16
# input and upcast to float32 for the comparison. The winning index of each
# row is written through the int64 output pointer.
# The kernel source string below is kept byte-for-byte as generated.
triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 16384, 'r0_': 32768},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}}
)
@triton.jit
def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 16384
    r0_numel = 32000
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
            _tmp2, _tmp2_index, tmp1, rindex
        )
        _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
        _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
    tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
    tmp2 = tmp2_idx[:, None]
    tl.store(out_ptr0 + (x0), tmp2, None)
''', device_str='cuda')
92
+
93
+
94
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fv/cfvm655j5cm4524gmdyhr7yli6dffpakysuycuozkqmyuaonkwbg.py
95
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
96
+ # Source node to ATen node mapping:
97
+ # argmax_1 => argmax_1
98
+ # Graph fragment:
99
+ # %arg1_1 : Tensor "f32[8, 2048, 32000][65760000, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1]
100
+ # %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
101
+ # return %argmax_1
102
# Row-wise argmax reduction kernel for the float32 input.
# Same reduction as triton_red_fused_argmax_0, but the flat row index is
# split into x0 = xindex % 2048 and x1 = xindex // 2048 because the input's
# outer stride is 65760000 rather than 2048 * 32000, so rows cannot be
# addressed with a single multiplier. The output is written contiguously at
# the flat row index.
# The kernel source string below is kept byte-for-byte as generated.
triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 16384, 'r0_': 32768},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}}
)
@triton.jit
def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 16384
    r0_numel = 32000
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 2048)
    x1 = xindex // 2048
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
            _tmp2, _tmp2_index, tmp1, rindex
        )
        _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
        _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
    tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
    tmp2 = tmp2_idx[:, None]
    tl.store(out_ptr0 + (x3), tmp2, None)
''', device_str='cuda')
151
+
152
+
153
# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/tz/ctz36xgzfd5jcgzrek7dztousbfkxkdiqxeixt3t36guolxobku7.py
# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
# Source node to ATen node mapping:
#   eq => eq
#   mul => mul
#   squeeze => squeeze
#   sum_1 => sum_1
# Graph fragment:
#   %argmax : Tensor "i64[8, 2048][2048, 1]cuda:3" = PlaceHolder[target=argmax]
#   %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:3" = PlaceHolder[target=argmax_1]
#   %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg2_1]
#   %eq : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
#   %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {})
#   %mul : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {})
#   %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
#   return %buf3
#
# Summary of the kernel body:
#   * in_ptr0 / in_ptr1 / in_ptr2 are read as int64 at flat offset r0 + 8192*x0,
#     i.e. the buffers are treated as 2 rows of 8192 elements (xnumel and
#     r0_numel are overwritten with these constants inside the kernel).
#   * Each element contributes (in_ptr0 == in_ptr1) cast to int64, times in_ptr2.
#   * Contributions are accumulated over the r0 loop and summed along axis 1,
#     so out_ptr0 receives one int64 partial sum per row (2 values in total).
#   * These partials are not yet the scalar sum_1: Runner.call passes the
#     2-element buffer on to triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.
triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 2, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}}
)
@triton.jit
def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 == tmp1
        tmp3 = tmp2.to(tl.int64)
        tmp5 = tmp3 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp7, xmask)
''', device_str='cuda')
216
+
217
+
218
# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/eu/ceuhopmcdleig6m43h7kk4fhghkl5w2umfjuyngxydc4pr3zpumg.py
# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum]
# Source node to ATen node mapping:
#   sum_2 => sum_2
# Graph fragment:
#   %arg3_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg3_1]
#   %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {})
#   return %buf5
#
# Summary of the kernel body:
#   * in_ptr0 is read as int64 at flat offset r0 + 8192*x0, i.e. as 2 rows of
#     8192 elements (xnumel and r0_numel are overwritten with these constants
#     inside the kernel).
#   * Values are accumulated over the r0 loop and summed along axis 1, so
#     out_ptr0 receives one int64 partial sum per row (2 values in total).
#   * The two partials are combined into the scalar sum_2 later, inside
#     triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 (see Runner.call).
triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 2, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}}
)
@triton.jit
def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 2
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
268
+
269
+
270
# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mw/cmw2hjnuubs2eh7cuc54pem6cjhaz4jgplmqlhrsxfkzljxf7ndg.py
# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
# Source node to ATen node mapping:
#   clamp_min => clamp_min
#   eq => eq
#   mul => mul
#   squeeze => squeeze
#   sum_1 => sum_1
#   sum_2 => sum_2
#   truediv => div
# Graph fragment:
#   %buf3 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=buf3]
#   %buf5 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=buf5]
#   %sum_1 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_1]
#   %sum_2 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_2]
#   %eq : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
#   %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {})
#   %mul : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {})
#   %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
#   %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {})
#   %clamp_min : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {})
#   %div : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {})
#   return %sum_1,%sum_2,%div
#
# Summary of the kernel body:
#   * R0_BLOCK is fixed to 2 and no masks are needed: the two int64 partials in
#     in_ptr0 and the two in in_ptr1 are each summed along axis 1, finishing
#     the sum_1 and sum_2 reductions started by the two preceding kernels.
#   * Both sums are cast to float32; the in_ptr1 sum is clamped from below at
#     1e-06 with triton_helpers.maximum before it is used as the divisor.
#   * The quotient is stored at offset 0 of out_ptr2, a float32 scalar buffer.
triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 1, 'r0_': 2},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}}
)
@triton.jit
def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2
    R0_BLOCK: tl.constexpr = 2
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_0 = r0_index
    tmp0 = tl.load(in_ptr0 + (r0_0), None)
    tmp4 = tl.load(in_ptr1 + (r0_0), None)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64)
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64)
    tmp8 = tmp3.to(tl.float32)
    tmp9 = tmp7.to(tl.float32)
    tmp10 = 1e-06
    tmp11 = triton_helpers.maximum(tmp9, tmp10)
    tmp12 = (tmp8 / tmp11)
    tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None)
''', device_str='cuda')
338
+
339
+
340
# Every kernel above was only registered via async_compile.triton(...).
# NOTE(review): wait(globals()) presumably blocks until those compilations
# finish and rebinds the module-level kernel names to the compiled objects --
# confirm against torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
# The compiler handle is not used again in this module, so drop the reference.
del async_compile
342
+
343
class Runner:
    """Owns the graph partitions and launches the compiled kernels in order."""

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace every partition with ``fn(partition)``.

        Functions and partitions are paired positionally with ``zip``, so any
        surplus on either side is silently dropped.
        """
        self.partitions = [fn(part) for fn, part in zip(fns, self.partitions)]

    def call(self, args):
        """Run the five-kernel pipeline and return a 1-tuple holding a float32 scalar.

        ``args`` must be a list of exactly four tensors; it is emptied right
        after unpacking so this frame holds the only references and each input
        can be released as soon as its last consumer has been launched.
        """
        argmax_src0, argmax_src1, sum1_factor, sum2_input = args
        args.clear()

        # Guard the layouts that the hard-coded kernel index arithmetic relies on.
        assert_size_stride(argmax_src0, (8, 2048, 32000), (65536000, 32000, 1))
        assert_size_stride(argmax_src1, (8, 2048, 32000), (65760000, 32000, 1))
        assert_size_stride(sum1_factor, (8, 2048, 1), (2048, 1, 1))
        assert_size_stride(sum2_input, (8, 2048, 1), (2048, 1, 1))

        with torch.cuda._DeviceGuard(3):
            torch.cuda.set_device(3)

            # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
            argmax0_out = empty_strided_cuda((8, 2048), (2048, 1), torch.int64)
            raw_stream = get_raw_stream(3)
            triton_red_fused_argmax_0.run(argmax_src0, argmax0_out, 16384, 32000, stream=raw_stream)
            del argmax_src0

            # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
            argmax1_out = empty_strided_cuda((8, 2048), (2048, 1), torch.int64)
            raw_stream = get_raw_stream(3)
            triton_red_fused_argmax_1.run(argmax_src1, argmax1_out, 16384, 32000, stream=raw_stream)
            del argmax_src1

            # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
            # Produces two partial sums (2 rows x 8192 elements), not yet the scalar.
            sum1_partials = empty_strided_cuda((2, ), (1, ), torch.int64)
            raw_stream = get_raw_stream(3)
            triton_red_fused_eq_mul_squeeze_sum_2.run(argmax0_out, argmax1_out, sum1_factor, sum1_partials, 2, 8192, stream=raw_stream)
            del sum1_factor
            del argmax0_out
            del argmax1_out

            # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum]
            sum2_partials = empty_strided_cuda((2, ), (1, ), torch.int64)
            raw_stream = get_raw_stream(3)
            triton_red_fused_sum_3.run(sum2_input, sum2_partials, 2, 8192, stream=raw_stream)
            del sum2_input

            # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
            # Finishes both reductions and writes sum_1 / clamp_min(sum_2, 1e-06).
            ratio_out = empty_strided_cuda((), (), torch.float32)
            raw_stream = get_raw_stream(3)
            triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(sum1_partials, sum2_partials, ratio_out, 1, 2, stream=raw_stream)
            del sum1_partials
            del sum2_partials
        return (ratio_out, )
391
+
392
# Single module-level Runner; it starts with no partitions.
runner = Runner(partitions=[])
# Expose the bound methods under plain module-level names so callers (and
# benchmark_compiled_module below) can use `call(...)` directly.
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
395
+
396
+
397
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on random inputs with the sizes/strides it asserts.

    Args:
        times: forwarded to ``print_performance``.
        repeat: forwarded to ``print_performance``.

    Returns:
        Whatever ``print_performance`` returns for the timed closure.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # (size, stride, dtype) for arg0_1 .. arg3_1, in call order.
    input_specs = (
        ((8, 2048, 32000), (65536000, 32000, 1), torch.bfloat16),
        ((8, 2048, 32000), (65760000, 32000, 1), torch.float32),
        ((8, 2048, 1), (2048, 1, 1), torch.int64),
        ((8, 2048, 1), (2048, 1, 1), torch.int64),
    )
    inputs = [
        rand_strided(size, stride, device='cuda:3', dtype=dtype)
        for size, stride, dtype in input_specs
    ]

    def run_once():
        # `call` clears the list it receives, so hand it a fresh list every time.
        return call(list(inputs))

    return print_performance(run_once, times=times, repeat=repeat)
406
+
407
+
408
if __name__ == "__main__":
    # When the file is executed directly, hand the benchmark entry point to
    # Inductor's compiled_module_main driver; the import stays local so merely
    # importing this module does not pull in the benchmarking helpers.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/5j/c5jvrblz5ym34kn4ssfnzfxabvx53ffrcnqlmwnjv3gmkqfkgo4v.py ADDED
@@ -0,0 +1,1083 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['13_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
# Shorthand handles for the operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Layout/alignment guards and raw allocators exposed through
# torch._C._dynamo.guards, bound to short module-level names.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# One AsyncCompile instance is shared by every async_compile.triton(...)
# kernel registration that follows in this module.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/s4/cs4afdu7ezaeekoshfdryoga6jabuq2nrx5xdkgxrehrkuvy5jri.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
#   %getitem : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem]
#   %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1]
#   %buf0 : Tensor "bf16[2, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0]
#   %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False})
#   %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
#   return %buf0,%buf1
#
# Summary of the kernel body:
#   * For every x index, in_ptr0 and in_ptr1 (both bf16, widened to float32 on
#     load) are multiplied element-wise over r0 in [0, 128) and the products
#     are summed along axis 1.
#   * in_ptr0 is addressed as r0 + 128*x1 + 4096*x0 + 4096*ks0*x2; in_ptr1 as
#     r0 + 128*x0 + 128*x5*m, where x5 = xindex // ks0 and
#     m = 1*(1 >= ks0) + ks0*(ks0 > 1), i.e. max(1, ks0).
#   * The float32 sum minus the constant 0.0 is stored to out_ptr1[xindex].
#   * NOTE(review): the result looks like the DELTA input (sum(OUT*DO, axis=-1))
#     described in the flex-attention backward template that follows -- confirm
#     against the call site, which is not visible in this part of the file.
triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 131072, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % ks0)
    x1 = ((xindex // ks0) % 32)
    x2 = xindex // ks1
    x5 = triton_helpers.div_floor_integer(xindex, ks0)
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x4 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_3 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp6 = tmp4.to(tl.float32)
    tmp7 = 0.0
    tmp8 = tmp6 - tmp7
    tl.store(out_ptr1 + (x4), tmp8, xmask)
''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3n/c3nhszg76l7meq3aapcfdfxknr3a44zlaammv6ewmbubopxjxjqh.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2]
104
+ # %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_4]
105
+ # %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_6]
106
+ # %getitem_1 : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=getitem_5]
111
+ # %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_13]
112
+ # %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_9]
113
+ # %primals_22 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:3" = PlaceHolder[target=primals_22]
114
+ # %primals_25 : Tensor "i32[2, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:3" = PlaceHolder[target=primals_25]
115
+ # %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:3" = PlaceHolder[target=primals_17]
116
+ # %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_20]
117
+ # %primals_27 : Tensor "i32[2, 1, s100][s100, s100, 1]cuda:3" = PlaceHolder[target=primals_27]
118
+ # %primals_30 : Tensor "i32[2, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:3" = PlaceHolder[target=primals_30]
119
+ # %primals_14 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_14]
120
+ # %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = False
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, which is written via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each key/value.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each key/value.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each key/value.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each key/value.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
223
+
224
+ ZQ = 2
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = ks0
228
+ ZKV = 2
229
+ KV_LEN = ks1
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 2
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = ks2
270
+ stride_kv_idx_h = ks3*ks4
271
+ stride_kv_idx_m = ks4
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = ks5
367
+ stride_q_idx_h = ks6*ks7
368
+ stride_q_idx_n = ks6
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = False
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = ks0
507
+ KV_LEN = ks1
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = False
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB reversed order since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
588
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = ks8
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
+ @triton.jit
681
+ def bwd_dkdv_inner(
682
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
683
+ Q, DO, DELTA, LSE, # pointers
684
+ dk, dv, k, v,
685
+ off_z, off_hq, offs_n1, offs_m1,
686
+ stride_qm, stride_qd, stride_dom, stride_dod,
687
+ q_indices, sparse_q_num_blocks,
688
+ MATMUL_PRECISION,
689
+ IS_FULL_BLOCKS,
690
+ ):
691
+ PRESCALE_QK : tl.constexpr = False
692
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
693
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
694
+ WRITE_DQ : tl.constexpr = True
695
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
696
+ OUTPUT_MAX : tl.constexpr = False
697
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
698
+ IS_DIVISIBLE : tl.constexpr = False
699
+ SM_SCALE : tl.constexpr = 0.08838834764831843
700
+ GQA_SHARED_HEADS : tl.constexpr = 4
701
+ HAS_FULL_BLOCKS : tl.constexpr = True
702
+ QK_HEAD_DIM : tl.constexpr = 128
703
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
704
+ V_HEAD_DIM : tl.constexpr = 128
705
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
706
+ SAFE_HEAD_DIM : tl.constexpr = True
707
+ BLOCK_M1 : tl.constexpr = 64
708
+ BLOCK_N1 : tl.constexpr = 128
709
+ BLOCK_M2 : tl.constexpr = 128
710
+ BLOCK_N2 : tl.constexpr = 64
711
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
712
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
713
+ INDEX_DTYPE : tl.constexpr = tl.int32
714
+
715
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
716
+ RCP_LN2: tl.constexpr = 1.44269504
717
+ Q_LEN = ks0
718
+ KV_LEN = ks1
719
+
720
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
721
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
722
+
723
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
724
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
725
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
726
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
727
+
728
+ # The minimum is needed to handle the case where we run with a super large
729
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
730
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
731
+
732
+ for start_m in range(0, hi):
733
+ dk, dv = bwd_dkdv_block_mn(
734
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
735
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
736
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
737
+ stride_qm, stride_qd, stride_dom, stride_dod,
738
+ q_indices, sparse_q_num_blocks,
739
+ MATMUL_PRECISION, RCP_LN2,
740
+ IS_FULL_BLOCKS,
741
+ )
742
+ # Increment pointers.
743
+ offset = get_offset_for_next_block(
744
+ start_m, q_indices, sparse_q_num_blocks,
745
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
746
+ )
747
+
748
+ qT_ptrs += offset * stride_qm
749
+ do_ptrs += offset * stride_dom
750
+ offs_m1 += offset
751
+
752
+ return dk, dv
753
+
754
+
755
+ @triton.jit
756
+ def bwd_dkdv_block_mn(
757
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
758
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
759
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
760
+ stride_qm, stride_qd, stride_dom, stride_dod,
761
+ q_indices, sparse_q_num_blocks,
762
+ MATMUL_PRECISION, RCP_LN2,
763
+ IS_FULL_BLOCKS,
764
+ ):
765
+ PRESCALE_QK : tl.constexpr = False
766
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
767
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
768
+ WRITE_DQ : tl.constexpr = True
769
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
770
+ OUTPUT_MAX : tl.constexpr = False
771
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
772
+ IS_DIVISIBLE : tl.constexpr = False
773
+ SM_SCALE : tl.constexpr = 0.08838834764831843
774
+ GQA_SHARED_HEADS : tl.constexpr = 4
775
+ HAS_FULL_BLOCKS : tl.constexpr = True
776
+ QK_HEAD_DIM : tl.constexpr = 128
777
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
778
+ V_HEAD_DIM : tl.constexpr = 128
779
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
780
+ SAFE_HEAD_DIM : tl.constexpr = True
781
+ BLOCK_M1 : tl.constexpr = 64
782
+ BLOCK_N1 : tl.constexpr = 128
783
+ BLOCK_M2 : tl.constexpr = 128
784
+ BLOCK_N2 : tl.constexpr = 64
785
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
786
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
787
+ INDEX_DTYPE : tl.constexpr = tl.int32
788
+
789
+
790
+ # NB reversed order since Q is transposed
791
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
792
+ # Load LSE before computing qk to reduce pipeline stall.
793
+ if IS_DIVISIBLE:
794
+ lse = tl.load(LSE + offs_m1)
795
+ else:
796
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
797
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
798
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
799
+ if not PRESCALE_QK:
800
+ qkT *= SM_SCALE
801
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
802
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
803
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
804
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
805
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
806
+
807
+ pre_mod_scores = qkT
808
+ tmp40 = (qkT)
809
+ post_mod_scores = tmp40
810
+
811
+
812
+
813
+ if not IS_DIVISIBLE:
814
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
815
+
816
+ if not IS_FULL_BLOCKS:
817
+ tmp41 = tl.full([1], False, tl.int1)
818
+ tmp42 = (m)
819
+ tmp43 = (n)
820
+ tmp44 = tmp42 >= tmp43
821
+ tmp45 = tmp43.to(tl.int64)
822
+ tmp46 = (off_z)
823
+ tmp47 = tl.load(in_ptr16 + tmp46)
824
+ tmp48 = tmp45 < tmp47
825
+ tmp49 = tmp42.to(tl.int64)
826
+ tmp50 = tmp49 < tmp47
827
+ tmp51 = tmp48 & tmp50
828
+ tmp52 = tmp44 & tmp51
829
+ tmp53 = tmp41 | tmp52
830
+ tmp54 = ks8
831
+ tmp55 = tmp43 >= tmp54
832
+ tmp56 = (tmp43 % tmp54)
833
+ tmp57 = tl.full([1], 0, tl.int32)
834
+ tmp58 = tmp56 != tmp57
835
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
836
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
837
+ tmp61 = tmp59 != tmp60
838
+ tmp62 = tmp58 & tmp61
839
+ tmp63 = tmp56 + tmp54
840
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
841
+ tmp65 = tmp64.to(tl.int64)
842
+ tmp66 = tmp65 < tmp47
843
+ tmp67 = tmp55 & tmp66
844
+ tmp68 = tmp43 - tmp42
845
+ tmp69 = (tmp68 % tmp54)
846
+ tmp70 = tmp69 != tmp57
847
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
848
+ tmp72 = tmp71 != tmp60
849
+ tmp73 = tmp70 & tmp72
850
+ tmp74 = tmp69 + tmp54
851
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
852
+ tmp76 = tmp75 == tmp57
853
+ tmp77 = tmp67 & tmp76
854
+ tmp78 = tmp53 | tmp77
855
+ mask_mod_output = tmp78
856
+
857
+ # apply mask for partially masked block
858
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ if not PRESCALE_QK:
861
+ post_mod_scores *= RCP_LN2
862
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
863
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
864
+ # Compute dV.
865
+ ppT = pT
866
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
867
+ if IS_DIVISIBLE:
868
+ Di = tl.load(DELTA + offs_m1)
869
+ else:
870
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
871
+ # Compute dP and dS.
872
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
873
+ dsT = pT * (dpT - Di[None, :])
874
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
875
+ tmp79 = (dsT)
876
+ grad_scores = tmp79
877
+
878
+
879
+
880
+ if not IS_DIVISIBLE:
881
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
882
+
883
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
884
+ if not WRITE_DQ:
885
+ idx_b = off_z
886
+ idx_h = off_hq
887
+ idx_m = m
888
+ idx_n = n
889
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
890
+
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
892
+ dsT = grad_scores
893
+ if not IS_FULL_BLOCKS:
894
+ # (grads) apply mask for partially unmasked block
895
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
896
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
897
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
898
+
899
+ return dk, dv
900
+
901
+ # Utility triton funcs
902
+ @triton.jit
903
+ def get_offset_for_next_block(
904
+ loop_iter, col_indices, total_blocks,
905
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
906
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
907
+ ):
908
+ if BLOCKS_ARE_CONTIGUOUS:
909
+ return BLOCK
910
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
911
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
912
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
913
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
914
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
915
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
916
+ return offset
917
+
918
+ @triton.jit
919
+ def get_bounded_indices(indices, max_len=None):
920
+ return indices % max_len if max_len is not None else indices
921
+
922
+ @triton.jit
923
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
924
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
925
+ return tl.load(block_ptr)
926
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
927
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
928
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
929
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
930
+ else:
931
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
+ @triton.jit
934
+ def load_checked_2d(
935
+ ptr,
936
+ offs_m,
937
+ offs_n,
938
+ stride_m,
939
+ stride_n,
940
+ IS_DIVISIBLE_M: tl.constexpr,
941
+ IS_DIVISIBLE_N: tl.constexpr,
942
+ M_LEN: tl.constexpr,
943
+ N_LEN: tl.constexpr,
944
+ ):
945
+ # Calculate final pointer if strides are provided
946
+ if stride_m is not None and stride_n is not None:
947
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
948
+
949
+ # Handle all masking cases
950
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
951
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
952
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
953
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
954
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
955
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
956
+ else: # Both divisible
957
+ return tl.load(ptr)
958
+ ''', device_str='cuda')
959
+
960
+
961
# Block until the kernels submitted through `async_compile` have finished
# compiling (wait() is handed this module's globals), then drop the helper
# object since nothing below refers to it again.
async_compile.wait(globals())
del async_compile
963
+
964
class Runner:
    """Holds the generated wrapper entry point for this compiled graph.

    `partitions` is stored as given and only touched by
    `recursively_apply_fns`; `call` launches the two kernels of this module
    (`triton_red_fused_zeros_0` and `triton_tem_fused_zeros_1`).
    """

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition with `fn(partition)`, pairing positionally.

        Pairing uses `zip`, so the result is as long as the shorter of `fns`
        and the current partition list.
        """
        self.partitions = [fn(part) for fn, part in zip(fns, self.partitions)]

    def call(self, args):
        """Run the compiled graph on `args` and return the output tuple.

        `args` must be a list of 30 values in the order unpacked below. The
        list is cleared in place so the caller's references to the inputs are
        released. Returns a 30-tuple in which only slots 1, 3 and 5 hold
        buffers (buf3, buf5, buf4); every other slot is None.
        """
        # The first 15 entries are the symbolic sizes; they arrive as
        # primals_10, primals_11, primals_15, primals_7, primals_8, primals_12,
        # primals_16, primals_18, primals_19, primals_21, primals_24,
        # primals_23, primals_26, primals_29, primals_28 (in that order).
        (
            s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6,
            primals_2, primals_4, primals_6, primals_9, primals_13, primals_14,
            primals_17, primals_20, primals_22, primals_25, primals_27, primals_30,
            getitem, getitem_1, tangents_1,
        ) = args
        args.clear()

        # Layout guards: each tensor input must match the size/stride the
        # kernels were generated for.
        assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
        assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
        assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
        assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1))
        assert_size_stride(primals_14, (2, ), (1, ))
        assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1))
        assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
        assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1))
        assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
        assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1))
        assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1))
        assert_size_stride(getitem, (2, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(getitem_1, (2, 32, s37), (32*max(1, s37), max(1, s37), 1))
        assert_size_stride(tangents_1, (2, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))

        # The CUDA device index 3 is baked into this generated wrapper.
        with torch.cuda._DeviceGuard(3):
            torch.cuda.set_device(3)
            ps0 = 32*s37
            buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
            triton_red_fused_zeros_0_xnumel = 64*s37
            stream3 = get_raw_stream(3)
            triton_red_fused_zeros_0.run(
                getitem, tangents_1, buf1, s37, ps0,
                triton_red_fused_zeros_0_xnumel, 128,
                stream=stream3,
            )
            del getitem

            buf3 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
            buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
            stream3 = get_raw_stream(3)
            grid_0 = 4*((127 + s37) // 128) + ((127 + s0) // 128)
            triton_tem_fused_zeros_1.run(
                primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1,
                buf3, buf4,
                primals_13, primals_9, primals_22, primals_25,
                primals_17, primals_20, primals_27, primals_30,
                primals_14, buf5,
                s37, s0, s99, s22, s72, s56, s53, s84, s75,
                grid_0, 2, 8,
                stream=stream3,
            )
            # Release every input and the scratch buffer as soon as the
            # kernel launch no longer needs the Python references.
            del buf1, getitem_1
            del primals_13, primals_14, primals_17, primals_2, primals_20
            del primals_22, primals_25, primals_27, primals_30
            del primals_4, primals_6, primals_9
            del tangents_1

        return (
            None, buf3, None, buf5, None, buf4,
            None, None, None, None, None, None,
            None, None, None, None, None, None,
            None, None, None, None, None, None,
            None, None, None, None, None, None,
        )
1038
+
1039
# Module-level entry points: one Runner created with an empty partition list,
# with its bound methods re-exported under the plain names `call` and
# `recursively_apply_fns`.
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
1042
+
1043
+
1044
def benchmark_compiled_module(times=10, repeat=10):
    """Time `call` on randomly filled inputs of the recorded example shapes.

    Builds the 15 scalar size arguments and 15 tensors on 'cuda:3', then
    returns whatever `print_performance` returns for the given `times` and
    `repeat`.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def make(size, stride, dtype):
        # Every benchmark tensor lives on the same hard-coded device.
        return rand_strided(size, stride, device='cuda:3', dtype=dtype)

    # Scalar size arguments.
    primals_10 = primals_11 = primals_15 = 2014
    primals_7 = primals_8 = primals_12 = primals_16 = 16
    primals_18 = primals_19 = primals_21 = primals_24 = 16
    primals_23 = primals_26 = primals_29 = primals_28 = 16

    # Tensor arguments, created in the same order as the original so the
    # random stream is consumed identically.
    primals_2 = make((2, 32, 2014, 128), (8249344, 128, 4096, 1), torch.bfloat16)
    primals_4 = make((2, 8, 2014, 128), (2062336, 257792, 128, 1), torch.bfloat16)
    primals_6 = make((2, 8, 2014, 128), (2062336, 257792, 128, 1), torch.bfloat16)
    primals_9 = make((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
    primals_13 = make((2, 1, 16), (16, 16, 1), torch.int32)
    primals_14 = make((2, ), (1, ), torch.int64)
    primals_17 = make((2, 1, 16), (16, 16, 1), torch.int32)
    primals_20 = make((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
    primals_22 = make((2, 1, 16), (16, 16, 1), torch.int32)
    primals_25 = make((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
    primals_27 = make((2, 1, 16), (16, 16, 1), torch.int32)
    primals_30 = make((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
    getitem = make((2, 32, 2014, 128), (8249344, 128, 4096, 1), torch.bfloat16)
    getitem_1 = make((2, 32, 2014), (64448, 2014, 1), torch.float32)
    tangents_1 = make((2, 32, 2014, 128), (8249344, 257792, 128, 1), torch.bfloat16)

    def run_once():
        # A fresh list is required on every invocation because `call`
        # clears the list it receives.
        return call([
            primals_10, primals_11, primals_15, primals_7, primals_8,
            primals_12, primals_16, primals_18, primals_19, primals_21,
            primals_24, primals_23, primals_26, primals_29, primals_28,
            primals_2, primals_4, primals_6, primals_9, primals_13,
            primals_14, primals_17, primals_20, primals_22, primals_25,
            primals_27, primals_30, getitem, getitem_1, tangents_1,
        ])

    return print_performance(run_once, times=times, repeat=repeat)
1079
+
1080
+
1081
# Script entry point: hand the benchmark function to Inductor's
# compiled-module driver. The first argument is the literal string 'None'.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/5o/c5oswnr7dpwwcqp5m7thpaf4owvpseka6amdvdroewnorzsre6n2.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Generated pointwise kernel. For each flat element it computes
#     out = in0[x] * in2[row1, x0] + shifted(in0)[x] * in3[row2, x0]
# where row1/row2 come from the int64 index tensor in_ptr1 and `shifted` is
# in0 with the two parts of each length-ks3 run swapped and the first output
# part negated.
# NOTE(review): this matches the usual "rotate-half" rotary-embedding pattern;
# presumably in_ptr2/in_ptr3 are cos/sin tables and in_ptr1 holds position
# ids - confirm against the originating graph.
@triton_heuristics.pointwise(
    size_hints={'x': 16777216},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
    # One program handles XBLOCK consecutive flat indices; xmask disables the
    # lanes past xnumel.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    # x2 selects the entry of the index tensor; x0 is the offset inside a run
    # of length ks3 and x5 is the run number.
    x2 = ((xindex // ks0) % ks1)
    x0 = (xindex % ks3)
    x5 = xindex // ks3
    tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
    # Wrap negative indices by adding ks2, then bounds-check on device.
    tmp2 = ks2
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
    # Gather row tmp5 of in_ptr2 (row length ks3) and form the first product.
    tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp8 = tmp0 * tmp7
    tmp9 = x0
    tmp10 = tl.full([1], 0, tl.int64)
    tmp11 = tmp9 >= tmp10  # computed by the generator but not used below
    # Split point of the run: ks3 - ks3 // 2.
    tmp12 = ks3 + (-1)*(ks3 // 2)
    tmp13 = tmp9 < tmp12
    # Below the split: negated element taken ks3 // 2 positions further on.
    tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp15 = -tmp14
    tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
    tmp17 = tl.where(tmp13, tmp15, tmp16)
    tmp18 = tmp9 >= tmp12
    tmp19 = ks3
    tmp20 = tmp9 < tmp19  # computed by the generator but not used below
    # At or above the split: element taken (ks3 - ks3 // 2) positions earlier.
    tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp22 = tl.where(tmp13, tmp17, tmp21)
    # Same index tensor, wrapped and checked against ks4 for the in_ptr3 gather.
    tmp23 = ks4
    tmp24 = tmp1 + tmp23
    tmp25 = tl.where(tmp4, tmp24, tmp1)
    tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
    tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp28 = tmp22 * tmp27
    tmp29 = tmp8 + tmp28
    tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/5o/fd9062ce8e19c42a2ac9826803e021d98494b253c62ff7bfe753f34e0c863929.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 50, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"}
SpecForge-ext/cache/compiled_kernels/5r/67cf09099929c6923eb0884c24406ef33ddeccd796aae57f0484cb9e81164741.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"}
SpecForge-ext/cache/compiled_kernels/5r/a92927e5b439ed1b110bb08d838b1d456f558deda739d61f97499093c88a877a.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "42NVHDOVRHC3TSIT2M6NVJU72L5EVVTGFXWS47GDCP2GM2XRN7KA"}
SpecForge-ext/cache/compiled_kernels/5r/c5rruemruzlohybkl4bagtqtk5athtuxf3eoj37rjptdltrrri3r.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Generated pointwise kernel. For each flat element it computes
#     out = in0[x] * in2[row1, x0] + shifted(in0)[x] * in3[row2, x0]
# where row1/row2 come from the int64 index tensor in_ptr1 and `shifted` is
# in0 with the two parts of each length-ks3 run swapped and the first output
# part negated.
# NOTE(review): this matches the usual "rotate-half" rotary-embedding pattern;
# presumably in_ptr2/in_ptr3 are cos/sin tables and in_ptr1 holds position
# ids - confirm against the originating graph.
@triton_heuristics.pointwise(
    size_hints={'x': 4194304},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
    # One program handles XBLOCK consecutive flat indices; xmask disables the
    # lanes past xnumel.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    # x2 selects the entry of the index tensor; x0 is the offset inside a run
    # of length ks3 and x5 is the run number.
    x2 = ((xindex // ks0) % ks1)
    x0 = (xindex % ks3)
    x5 = xindex // ks3
    tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
    # Wrap negative indices by adding ks2, then bounds-check on device.
    tmp2 = ks2
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
    # Gather row tmp5 of in_ptr2 (row length ks3) and form the first product.
    tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp8 = tmp0 * tmp7
    tmp9 = x0
    tmp10 = tl.full([1], 0, tl.int64)
    tmp11 = tmp9 >= tmp10  # computed by the generator but not used below
    # Split point of the run: ks3 - ks3 // 2.
    tmp12 = ks3 + (-1)*(ks3 // 2)
    tmp13 = tmp9 < tmp12
    # Below the split: negated element taken ks3 // 2 positions further on.
    tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp15 = -tmp14
    tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
    tmp17 = tl.where(tmp13, tmp15, tmp16)
    tmp18 = tmp9 >= tmp12
    tmp19 = ks3
    tmp20 = tmp9 < tmp19  # computed by the generator but not used below
    # At or above the split: element taken (ks3 - ks3 // 2) positions earlier.
    tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp22 = tl.where(tmp13, tmp17, tmp21)
    # Same index tensor, wrapped and checked against ks4 for the in_ptr3 gather.
    tmp23 = ks4
    tmp24 = tmp1 + tmp23
    tmp25 = tl.where(tmp4, tmp24, tmp1)
    tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
    tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp28 = tmp22 * tmp27
    tmp29 = tmp8 + tmp28
    tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/5r/c5rs7ak7dbb5csdzszewryjtnxnhv7xpwdjvgqrsp5lfmmej4poi.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Generated pointwise kernel that writes int32 zeros to the first `xnumel`
# elements of out_ptr0.
@triton_heuristics.pointwise(
    size_hints={'x': 256},
    filename=__file__,
    triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Each program covers XBLOCK consecutive indices; xmask guards the tail.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    # A single-element zero is broadcast by the masked store.
    tmp0 = tl.full([1], 0, tl.int32)
    tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/5w/c5w735qbviioww7vfjj36tk57xo254oei3wqkunaiekkjd5pfcph.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Generated reduction kernel: for every x, sums r0_numel int32 values read
# from in_ptr0 at offset r0 + ks0*x. The sum is accumulated in int64, cast
# back to int32, and written to out_ptr1.
@triton_heuristics.reduction(
    size_hints={'x': 128, 'r0_': 16},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    # x indexes the outputs (column vector), r0 the reduced axis (row vector).
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Per-lane int64 partial sums, folded across the r0 axis after the loop.
    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.int64)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        tmp4 = _tmp3 + tmp2
        # Only lanes that are in range update their accumulator.
        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    x2 = (xindex % ks1)
    x3 = xindex // ks1
    tmp5 = tmp3.to(tl.int32)
    # The row stride below evaluates to ks1 when ks1 > 1 and to 1 otherwise.
    tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask)
SpecForge-ext/cache/compiled_kernels/5z/c5z5oj5ee2bvvg2pkzwf6smszdy73565nillm7gopvokmvrvu2dp.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['13_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
# Short module-level aliases for op namespaces and for the guard, allocation
# and reinterpret helpers exposed on torch._C._dynamo.guards.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compile helper that the kernel definitions below are submitted through.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2]
44
+ # %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_4]
45
+ # %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_6]
46
+ # %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1]
48
+ # %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13]
49
+ # %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9]
50
+ # %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_17]
51
+ # %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_20]
52
+ # %primals_14 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=primals_14]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = False
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
142
+
143
+ ZQ = 2
144
+ HQ = 32
145
+ Q_LEN = ks0
146
+ ZKV = 2
147
+ KV_LEN = ks1
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 2
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = ks2
184
+ stride_kv_idx_h = ks3*ks4
185
+ stride_kv_idx_m = ks4
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = False
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB reversed order to since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = ks5
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = False
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
576
+ # it's on par or slightly faster than only applying to the last block in fwd.
577
+ # However, we choose different strategy for bwd, where we only apply mod & mask
578
+ # to the last block because it's faster a lot.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
# Wait for the kernels submitted through async_compile before any wrapper code runs.
# NOTE(review): presumably wait() also installs the finished kernels into the
# globals() dict passed here -- confirm against torch._inductor.async_compile.
+ async_compile.wait(globals())
611
# The compile handle is not needed after the wait, so the module-level name is dropped.
+ del async_compile
612
+
613
class Runner:
    """Generated wrapper that validates inputs and launches the flex_attention kernel.

    ``partitions`` is stored as given; only ``recursively_apply_fns`` touches
    it, ``call`` never reads it.
    """

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace every stored partition with ``fn(partition)``.

        ``fns`` and ``self.partitions`` are paired positionally with ``zip``,
        so the new list is as long as the shorter of the two.
        """
        self.partitions = [fn(part) for fn, part in zip(fns, self.partitions)]

    def call(self, args):
        """Run the fused kernel on the 30 positional inputs in ``args``.

        ``args`` must be a list of exactly 30 items; it is emptied right after
        unpacking so the caller's list no longer keeps the inputs alive.
        Tensor sizes and strides are checked with ``assert_size_stride`` before
        three output buffers are allocated on CUDA device 4 and the kernel is
        launched.

        Returns a tuple: ``buf2`` first, then the tensor inputs, ``buf2``
        again, ``buf0`` and finally the symbolic sizes.
        """
        primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args
        args.clear()
        # Symbolic sizes. primals_1, primals_5 and primals_11 are unpacked above
        # but never used afterwards, so they get no local alias.
        s0 = primals_3
        s22 = primals_7
        s72 = primals_8
        s37 = primals_10
        s99 = primals_12
        s75 = primals_15
        s94 = primals_16
        s28 = primals_18
        s4 = primals_19
        s56 = primals_21
        s84 = primals_23
        s53 = primals_24
        s100 = primals_26
        s6 = primals_28
        s10 = primals_29
        assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
        assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
        assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
        assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1))
        assert_size_stride(primals_14, (2, ), (1, ))
        assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1))
        assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
        assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1))
        assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
        assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1))
        assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1))
        with torch.cuda._DeviceGuard(4):
            torch.cuda.set_device(4)
            # buf0 and buf1 are passed as the 4th and 5th kernel arguments
            # (presumably the log-sum-exp and row-max outputs -- confirm against
            # the kernel signature); buf2 receives the bfloat16 result.
            buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32)
            buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32)
            buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream4 = get_raw_stream(4)
            # The three values before ``stream`` are ceil(s37 / 128), 2 and 32.
            triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream4)
            # buf1 is not part of the returned tuple, so release it immediately.
            del buf1
        return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, )


# Module-level entry points bound to a single shared Runner instance.
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
670
+
671
+
672
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark the compiled wrapper ``call`` on randomly generated inputs.

    Builds the 30 positional inputs once (integer sizes 1543 and 13, tensors
    on ``cuda:4``) and passes a zero-argument launcher to
    ``print_performance`` together with ``times`` and ``repeat``. Whatever
    that helper returns is returned unchanged.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def bf16_input(size, stride):
        return rand_strided(size, stride, device='cuda:4', dtype=torch.bfloat16)

    def int32_vector():
        return rand_strided((2, 1, 13), (13, 13, 1), device='cuda:4', dtype=torch.int32)

    def int32_matrix():
        return rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:4', dtype=torch.int32)

    # Tuple elements are evaluated left to right, so the random tensors are
    # created in the same order as primals_2, primals_4, ..., primals_30.
    inputs = (
        1543,                                                            # primals_1
        bf16_input((2, 32, 1543, 128), (6320128, 128, 4096, 1)),         # primals_2
        1543,                                                            # primals_3
        bf16_input((2, 8, 1543, 128), (1580032, 197504, 128, 1)),        # primals_4
        1543,                                                            # primals_5
        bf16_input((2, 8, 1543, 128), (1580032, 197504, 128, 1)),        # primals_6
        13,                                                              # primals_7
        13,                                                              # primals_8
        int32_matrix(),                                                  # primals_9
        1543,                                                            # primals_10
        1543,                                                            # primals_11
        13,                                                              # primals_12
        int32_vector(),                                                  # primals_13
        rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64),  # primals_14
        1543,                                                            # primals_15
        13,                                                              # primals_16
        int32_vector(),                                                  # primals_17
        13,                                                              # primals_18
        13,                                                              # primals_19
        int32_matrix(),                                                  # primals_20
        13,                                                              # primals_21
        int32_vector(),                                                  # primals_22
        13,                                                              # primals_23
        13,                                                              # primals_24
        int32_matrix(),                                                  # primals_25
        13,                                                              # primals_26
        int32_vector(),                                                  # primals_27
        13,                                                              # primals_28
        13,                                                              # primals_29
        int32_matrix(),                                                  # primals_30
    )

    def launch():
        # call() empties the list it is given, so hand it a fresh list per run.
        return call(list(inputs))

    return print_performance(launch, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/6j/c6jzxztdxbjv5b23nfmgzgtizqp77h7aeak5j2jukmz3roqeiw3k.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
@triton_heuristics.pointwise(
    size_hints={'x': 512},
    filename=__file__,
    triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Zero-fill kernel: every program instance handles XBLOCK consecutive
    # int32 elements of out_ptr0, starting at program_id(0) * XBLOCK.
    block_start = tl.program_id(0) * XBLOCK
    offsets = block_start + tl.arange(0, XBLOCK)[:]
    # Lanes at or beyond xnumel are masked off so the last block cannot write
    # past the end of the buffer.
    in_bounds = offsets < xnumel
    zero = tl.full([1], 0, tl.int32)
    tl.store(out_ptr0 + (offsets), zero, in_bounds)
SpecForge-ext/cache/compiled_kernels/6j/e9590d30530b6f20cd8332cd18dfb56bc33c5ce0f73ebafd83fbd8da1a7ab8fe.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "42NVHDOVRHC3TSIT2M6NVJU72L5EVVTGFXWS47GDCP2GM2XRN7KA"}
SpecForge-ext/cache/compiled_kernels/6m/c6mwcfy2ykv3p5alrzh4sx4ajhl5davetqobw2pytyc2kalbo2wk.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['10_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mb/cmboruk2gyuhq43degftaqzb2abxergkmetbzmcgprn7eynqywpe.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul_6
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg1_1 : Tensor "bf16[8, s14, 151936][151936*s14, 151936, 1]cuda:7" = PlaceHolder[target=arg1_1]
47
+ # %argmax : Tensor "i64[8, s14][s14, 1]cuda:7" = PlaceHolder[target=argmax]
48
+ # %arg2_1 : Tensor "b8[151936][1]cuda:7" = PlaceHolder[target=arg2_1]
49
+ # %arg3_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:7" = PlaceHolder[target=arg3_1]
50
+ # %argmax : Tensor "i64[8, s14][s14, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[8, s14][s14, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[8, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[8, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul_6 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {})
55
+ # return %argmax,%mul_6
56
# Fused reduction kernel. For each row x0 it computes the argmax over the
# 151936-element inner axis of the bf16 input (in_ptr0, compared in float32),
# uses that index to look up a boolean in in_ptr1, widens the boolean to int64,
# multiplies it by in_ptr2[x0] and stores the product to in_out_ptr0[x0].
# The kernel source is passed to Triton as a string, so its text is kept
# exactly as Inductor generated it; only the layout lost in extraction
# (diff markers, gutter numbers, indentation) has been restored.
triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 16384, 'r0_': 262144},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 151936
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
            _tmp2, _tmp2_index, tmp1, rindex
        )
        _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
        _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
    tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
    tmp2 = tmp2_idx[:, None]
    tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936")
    tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1)
    tmp9 = tmp8.to(tl.int32)
    tmp10 = tmp9.to(tl.int64)
    tmp12 = tmp10 * tmp11
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, xmask)
''', device_str='cuda')
113
+
114
+
115
# Block until every kernel submitted through async_compile has finished
# compiling, then drop the compiler handle from the module namespace.
async_compile.wait(globals())
del async_compile
117
+
118
class Runner:
    """Entry point wrapper for the compiled graph in this module.

    Holds a list of partition callables and exposes ``call``, which launches
    the fused argmax / index / mul Triton kernel on CUDA device 7.
    """

    def __init__(self, partitions):
        """Store the list of partition callables."""
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition ``c`` with ``fn(c)``, pairing ``fns`` positionally.

        Pairing uses ``zip``, so the result is truncated to the shorter of
        ``fns`` and the current partition list.
        """
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        """Run the compiled graph.

        Args:
            args: list ``[arg0_1, arg1_1, arg2_1, arg3_1]``. The list is
                cleared in place once its items have been unpacked.

        Returns:
            A 1-tuple holding an int64 tensor of shape ``(8, s14, 1)``, where
            ``s14`` is ``arg1_1.size()[1]``.
        """
        arg0_1, arg1_1, arg2_1, arg3_1 = args
        args.clear()
        s24 = arg0_1  # bound by the generator; not read again in this graph
        arg1_1_size = arg1_1.size()
        s14 = arg1_1_size[1]
        assert_size_stride(arg1_1, (8, s14, 151936), (151936*s14, 151936, 1))
        assert_size_stride(arg2_1, (151936, ), (1, ))
        assert_size_stride(arg3_1, (8, s14, 1), (s14, 1, 1))
        with torch.cuda._DeviceGuard(7):
            torch.cuda.set_device(7)
            buf0 = empty_strided_cuda((8, s14), (s14, 1), torch.int64)
            # View the same storage with a trailing unit dimension; buf0's name is released.
            buf1 = reinterpret_tensor(buf0, (8, s14, 1), (s14, 1, 1), 0); del buf0  # reuse
            # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
            triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 8*s14
            stream7 = get_raw_stream(7)
            triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream7)
            del arg1_1
            del arg2_1
            del arg3_1
        return (buf1, )
149
+
150
# Module-level singleton; `call` and `recursively_apply_fns` are exported as
# bound methods of it.
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
153
+
154
+
155
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on randomly generated example inputs.

    Builds inputs on ``cuda:7`` with the example sizes baked in below
    (``s14 = 2025``) and forwards ``times`` / ``repeat`` to
    ``print_performance``, whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 2025
    arg1_1 = rand_strided((8, 2025, 151936), (307670400, 151936, 1), device='cuda:7', dtype=torch.bfloat16)
    arg2_1 = rand_strided((151936, ), (1, ), device='cuda:7', dtype=torch.bool)
    arg3_1 = rand_strided((8, 2025, 1), (2025, 1, 1), device='cuda:7', dtype=torch.int64)
    # `call` clears the list it receives, so a fresh list is built per invocation.
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
    return print_performance(fn, times=times, repeat=repeat)
164
+
165
+
166
# When executed as a script, hand the benchmark function to Inductor's
# wrapper-benchmark entry point.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/6o/c6o7jlqhfbi4ry3uni47hefilsmfptqopfdxwc3plgg65s2mqzse.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
# Short module-level aliases for operator namespaces, guard helpers and
# allocation helpers used by the generated wrapper code below.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler handle through which the Triton kernels below are submitted.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gy/cgypquf4bysldt6yik5b24uoywlbfrbaqlpvmqscjgvur4u7ckpi.py
38
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cos => squeeze_1
41
+ # cos_1 => unsqueeze
42
+ # getitem => index
43
+ # getitem_1 => index_1
44
+ # sin => squeeze_3
45
+ # sin_1 => unsqueeze_1
46
+ # squeeze => squeeze
47
+ # squeeze_2 => squeeze_2
48
+ # Graph fragment:
49
+ # %tangents_2 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5" = PlaceHolder[target=tangents_2]
50
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8]
51
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6]
52
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4]
53
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
54
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
55
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
56
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
57
+ # %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {})
58
+ # %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {})
59
+ # %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {})
60
+ # %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {})
61
+ # %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False})
62
+ # %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {})
63
+ # %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {})
64
+ # %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
65
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
66
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
67
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
68
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
69
+ # %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {})
70
+ # %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {})
71
+ # return %add_101
72
# Fused pointwise kernel. For each output element it combines three products
# of in_ptr0 with rows gathered through the int64 indices in in_ptr1:
#   * where x0 >= ks0 // 2: the negated product of in_ptr0 (shifted back by
#     ks0 // 2) with the matching in_ptr2 row,
#   * where x0 <  ks0 // 2: the product of in_ptr0 (shifted forward by
#     ks0 - ks0 // 2) with the matching in_ptr2 row,
#   * everywhere: in_ptr0 times the in_ptr3 row.
# Negative indices are wrapped and bounds-checked with tl.device_assert.
# The kernel source is passed to Triton as a string, so its text is kept
# exactly as Inductor generated it; only the layout lost in extraction has
# been restored.
triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4194304},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % ks0)
    x3 = xindex
    x1 = ((xindex // ks0) % ks1)
    tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
    tmp0 = x0
    tmp1 = ks0 // 2
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
    tmp5 = tl.broadcast_to(ks2, [XBLOCK])
    tmp6 = tmp4 + tmp5
    tmp7 = tmp4 < 0
    tmp8 = tl.where(tmp7, tmp6, tmp4)
    tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
    tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp3 * tmp10
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp2, tmp12, tmp13)
    tmp15 = 0.0
    tmp16 = tl.where(tmp2, tmp14, tmp15)
    tmp17 = tmp0 < tmp1
    tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
    tmp20 = tl.broadcast_to(ks2, [XBLOCK])
    tmp21 = tmp19 + tmp20
    tmp22 = tmp19 < 0
    tmp23 = tl.where(tmp22, tmp21, tmp19)
    tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
    tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp26 = tmp18 * tmp25
    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
    tmp28 = tl.where(tmp17, tmp26, tmp27)
    tmp29 = tl.where(tmp17, tmp28, tmp15)
    tmp30 = tmp16 + tmp29
    tmp33 = ks3
    tmp34 = tmp32 + tmp33
    tmp35 = tmp32 < 0
    tmp36 = tl.where(tmp35, tmp34, tmp32)
    tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
    tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp39 = tmp31 * tmp38
    tmp40 = tmp30 + tmp39
    tl.store(out_ptr0 + (x3), tmp40, xmask)
''', device_str='cuda')
139
+
140
+
141
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bp/cbpuekklhjdszdnpjmnzg77zhi5rum3iueweicitfxwda6abrl2a.py
142
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
143
+ # Source node to ATen node mapping:
144
+ # cos => squeeze_1
145
+ # cos_1 => unsqueeze
146
+ # getitem => index
147
+ # getitem_1 => index_1
148
+ # sin => squeeze_3
149
+ # sin_1 => unsqueeze_1
150
+ # squeeze => squeeze
151
+ # squeeze_2 => squeeze_2
152
+ # Graph fragment:
153
+ # %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5" = PlaceHolder[target=tangents_1]
154
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8]
155
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4]
157
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
158
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
159
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
160
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
161
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
162
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
163
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
164
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
165
+ # %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {})
166
+ # %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {})
167
+ # %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {})
168
+ # %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {})
169
+ # %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False})
170
+ # %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {})
171
+ # %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {})
172
+ # %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {})
173
+ # %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {})
174
+ # %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {})
175
+ # return %add_107
176
# Second instance of the fused pointwise kernel. Its body has the same
# statements as the `_0` variant in this module; the visible differences are
# the kernel name and the larger size hint (16777216). The kernel source is
# passed to Triton as a string, so its text is kept exactly as Inductor
# generated it; only the layout lost in extraction has been restored.
triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % ks0)
    x3 = xindex
    x1 = ((xindex // ks0) % ks1)
    tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
    tmp0 = x0
    tmp1 = ks0 // 2
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
    tmp5 = tl.broadcast_to(ks2, [XBLOCK])
    tmp6 = tmp4 + tmp5
    tmp7 = tmp4 < 0
    tmp8 = tl.where(tmp7, tmp6, tmp4)
    tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
    tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp3 * tmp10
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp2, tmp12, tmp13)
    tmp15 = 0.0
    tmp16 = tl.where(tmp2, tmp14, tmp15)
    tmp17 = tmp0 < tmp1
    tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
    tmp20 = tl.broadcast_to(ks2, [XBLOCK])
    tmp21 = tmp19 + tmp20
    tmp22 = tmp19 < 0
    tmp23 = tl.where(tmp22, tmp21, tmp19)
    tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
    tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp26 = tmp18 * tmp25
    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
    tmp28 = tl.where(tmp17, tmp26, tmp27)
    tmp29 = tl.where(tmp17, tmp28, tmp15)
    tmp30 = tmp16 + tmp29
    tmp33 = ks3
    tmp34 = tmp32 + tmp33
    tmp35 = tmp32 < 0
    tmp36 = tl.where(tmp35, tmp34, tmp32)
    tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
    tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp39 = tmp31 * tmp38
    tmp40 = tmp30 + tmp39
    tl.store(out_ptr0 + (x3), tmp40, xmask)
''', device_str='cuda')
243
+
244
+
245
# Block until every kernel submitted through async_compile has finished
# compiling, then drop the compiler handle from the module namespace.
async_compile.wait(globals())
del async_compile
247
+
248
class Runner:
    """Entry point wrapper for the compiled backward graph in this module.

    Holds a list of partition callables and exposes ``call``, which launches
    the two fused pointwise Triton kernels on CUDA device 5.
    """

    def __init__(self, partitions):
        """Store the list of partition callables."""
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition ``c`` with ``fn(c)``, pairing ``fns`` positionally.

        Pairing uses ``zip``, so the result is truncated to the shorter of
        ``fns`` and the current partition list.
        """
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        """Run the compiled backward graph.

        Args:
            args: list of 15 items: the symbolic sizes, ``floordiv`` and
                ``add_96``, followed by the tensors ``primals_4``,
                ``primals_6``, ``primals_8``, ``tangents_1`` and
                ``tangents_2``. The list is cleared in place after unpacking.

        Returns:
            A tuple whose only non-``None`` entries are ``buf1`` (computed
            from ``tangents_1``) and ``buf0`` (computed from ``tangents_2``),
            each a bfloat16 tensor with the same size and stride as its
            tangent.
        """
        primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args
        args.clear()
        # Symbolic sizes referenced by the size/stride checks and kernel launches below.
        s24 = primals_2
        s9 = primals_7
        s48 = primals_10
        s34 = primals_11
        s25 = primals_13
        s92 = primals_1
        s96 = primals_3
        s79 = primals_5
        assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
        assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
        assert_size_stride(primals_8, (1, s9), (s9, 1))
        assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1))
        assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1))
        with torch.cuda._DeviceGuard(5):
            torch.cuda.set_device(5)
            buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
            triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9
            stream5 = get_raw_stream(5)
            triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream5)
            del tangents_2
            buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
            triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9
            stream5 = get_raw_stream(5)
            triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream5)
            del primals_4
            del primals_6
            del primals_8
            del tangents_1
        return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, )
292
+
293
# Module-level singleton; `call` and `recursively_apply_fns` are exported as
# bound methods of it.
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
296
+
297
+
298
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on randomly generated example inputs.

    Builds the tensors on ``cuda:5`` with the example sizes baked in below
    and forwards ``times`` / ``repeat`` to ``print_performance``, whose
    result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_2 = 128
    primals_7 = 2048
    primals_10 = 2
    primals_11 = 32
    primals_13 = 8
    primals_1 = 2048
    primals_3 = 5245440
    primals_5 = 2048
    floordiv = 64
    add_96 = 64
    primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16)
    primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16)
    primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:5', dtype=torch.int64)
    tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16)
    tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16)
    # `call` clears the list it receives, so a fresh list is built per invocation.
    fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2])
    return print_performance(fn, times=times, repeat=repeat)
318
+
319
+
320
# When executed as a script, hand the benchmark function to Inductor's
# wrapper-benchmark entry point.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/6r/734ee9f72fcbbc036c304bd9fc428175dc6febf6da61f182679d20ad4d8b7f41.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"}
SpecForge-ext/cache/compiled_kernels/6r/c6r6adrqwwhzfcdd5cyhmwl3cptpvwwhedzdpranw7esxeg5oyia.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Pointwise kernel over a flat index range of ``xnumel`` elements.
#
# For every flat index i it computes
#     out_ptr0[i] = in_ptr0[i] * in_ptr2[col + ks3*idx2]
#                 + swapped(i) * in_ptr3[col + ks3*idx3]
# where col = i % ks3, the raw index is read from in_ptr1 and wrapped by
# ks2 (for in_ptr2) or ks4 (for in_ptr3) when negative, and ``swapped`` is
# the element of the same ks3-wide row taken from the other half of the row,
# negated when ``col`` lies in the first half.
# NOTE(review): this looks like a rotary-position-embedding application
# (x * cos + rotate_half(x) * sin, gathered by position id) - confirm
# against the traced model.
@triton_heuristics.pointwise(
    size_hints={'x': 16777216},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
    # Flat element indices owned by this program instance, plus tail mask.
    block_start = tl.program_id(0) * XBLOCK
    flat_idx = block_start + tl.arange(0, XBLOCK)[:]
    in_range = flat_idx < xnumel

    # Index decomposition used by the loads below.
    index_slot = ((flat_idx // ks0) % ks1)   # which entry of in_ptr1 to read
    col = (flat_idx % ks3)                   # position inside a ks3-wide row
    row = flat_idx // ks3                    # which ks3-wide row

    value = tl.load(in_ptr0 + (flat_idx), in_range, eviction_policy='evict_last').to(tl.float32)
    raw_index = tl.load(in_ptr1 + (index_slot), in_range, eviction_policy='evict_last')
    index_is_negative = raw_index < 0

    # First gathered factor: wrap negative indices by ks2, then bounds-check.
    # The assert text is kept verbatim from the generated code.
    index_a = tl.where(index_is_negative, raw_index + ks2, raw_index)
    tl.device_assert(((0 <= index_a) & (index_a < ks2)) | ~(in_range), "index out of bounds: 0 <= tmp5 < ks2")
    factor_a = tl.load(in_ptr2 + (col + ks3*index_a), in_range, eviction_policy='evict_last').to(tl.float32)
    direct_term = value * factor_a

    # Half-swapped copy of the row: columns below ``split`` read from the
    # upper part of the row and are negated; the rest read from the lower part.
    split = ks3 + (-1)*(ks3 // 2)
    in_first_half = col < split
    upper = tl.load(in_ptr0 + (ks3*row + (ks3 // 2) + (col)), in_first_half & in_range, eviction_policy='evict_last', other=0.0).to(tl.float32)
    neg_upper = -upper
    zero_fill = tl.full(neg_upper.shape, 0.0, neg_upper.dtype)
    first_half_value = tl.where(in_first_half, neg_upper, zero_fill)
    in_second_half = col >= split
    lower = tl.load(in_ptr0 + (ks3*row + (col + ((-1)*ks3) + (ks3 // 2))), in_second_half & in_range, eviction_policy='evict_last', other=0.0).to(tl.float32)
    swapped = tl.where(in_first_half, first_half_value, lower)

    # Second gathered factor: same raw index, wrapped and checked against ks4.
    index_b = tl.where(index_is_negative, raw_index + ks4, raw_index)
    tl.device_assert(((0 <= index_b) & (index_b < ks4)) | ~(in_range), "index out of bounds: 0 <= tmp25 < ks4")
    factor_b = tl.load(in_ptr3 + (col + ks3*index_b), in_range, eviction_policy='evict_last').to(tl.float32)
    swapped_term = swapped * factor_b

    result = direct_term + swapped_term
    tl.store(out_ptr0 + (flat_idx), result, in_range)
SpecForge-ext/cache/compiled_kernels/6r/c6rbvgm53jr3nux66durqhisanccgaebzxcdjdhdrphqjpyu2t5r.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Row-wise reduction kernel over rows of 32000 contiguous elements
# (addresses are ``col + 32000*row``).
#
# Pass 1 streams each row in R0_BLOCK-wide chunks and folds it into a running
# (max, sum) pair through ``triton_helpers.online_softmax_combine`` /
# ``online_softmax_reduce``.  Pass 2 re-reads the row and stores
# ``exp(x - max) / sum`` to out_ptr2.  Per the declared signature the input is
# bf16 and the output fp32; all arithmetic is done in fp32.
@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 32768},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # The reduction length is pinned to a literal, overriding the argument.
    r0_numel = 32000

    # Rows handled by this program (column vector) and the per-chunk column
    # lanes (row vector); together they broadcast to [XBLOCK, R0_BLOCK].
    row_start = tl.program_id(0) * XBLOCK
    rows = row_start + tl.arange(0, XBLOCK)[:, None]
    row_ok = rows < xnumel
    col_lanes = tl.arange(0, R0_BLOCK)[None, :]

    # ---- Pass 1: running max / running sum accumulators per lane ----------
    running_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
    running_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    for chunk_start in range(0, r0_numel, R0_BLOCK):
        cols = chunk_start + col_lanes
        valid = (cols < r0_numel) & row_ok
        chunk = tl.load(in_ptr0 + (cols + 32000*rows), valid, eviction_policy='evict_last', other=0.0).to(tl.float32)
        chunk_full = tl.broadcast_to(chunk, [XBLOCK, R0_BLOCK])

        next_max, next_sum = triton_helpers.online_softmax_combine(
            running_max, running_sum, chunk_full, False
        )

        # Only lanes that actually loaded data may update the accumulators.
        running_max = tl.where(valid, next_max, running_max)
        running_sum = tl.where(valid, next_sum, running_sum)

    # Collapse the lane dimension into one (max, sum) pair per row.
    row_max, row_sum = triton_helpers.online_softmax_reduce(
        running_max, running_sum, 1, False)
    row_max = row_max[:, None]
    row_sum = row_sum[:, None]

    # ---- Pass 2: normalise and write out ----------------------------------
    for chunk_start in range(0, r0_numel, R0_BLOCK):
        cols = chunk_start + col_lanes
        valid = (cols < r0_numel) & row_ok
        # evict_first: the original marks this second read of the row with a
        # different cache-eviction hint than the first pass.
        chunk = tl.load(in_ptr0 + (cols + 32000*rows), valid, eviction_policy='evict_first', other=0.0).to(tl.float32)
        shifted = chunk - row_max
        exponent = libdevice.exp(shifted)
        normalised = (exponent / row_sum)
        tl.store(out_ptr2 + (cols + 32000*rows), normalised, valid)