Lekr0 commited on
Commit
a105b1f
·
verified ·
1 Parent(s): 80d5b8b

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SpecForge-ext/cache/compiled_kernels/27/178fbe80a6655cd928415af01c18f906aac88a91f6ad870b6e019e505e41d8d6.best_config +1 -0
  2. SpecForge-ext/cache/compiled_kernels/27/9e94f85bd7b5310932ad7debf393fd211c0f4a83a0599bb42899fad47199226a.best_config +1 -0
  3. SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py +56 -0
  4. SpecForge-ext/cache/compiled_kernels/27/c27s4qoyzyvf54snkgtay3lqlnoj3bgphotvv5xwczxe6bqovure.py +49 -0
  5. SpecForge-ext/cache/compiled_kernels/3j/319ae573ab7247866e4ff0749ebbc205378f8b72492d64376005ae12fdcc85cb.best_config +1 -0
  6. SpecForge-ext/cache/compiled_kernels/3j/c3j47dekusw3y4mohtk5v36cc6fso3wdtqn5oqjwew3yy3exjo76.py +49 -0
  7. SpecForge-ext/cache/compiled_kernels/3r/c3rfwo25yzbkbl5er7svhpxhxxjdbre5zoxeq5wwcwlsvq2puinx.py +322 -0
  8. SpecForge-ext/cache/compiled_kernels/3z/c3zdyaemmekwelpoed5jduslt3o4gp6avp6it2wx3udu2z3kxz65.py +309 -0
  9. SpecForge-ext/cache/compiled_kernels/43/c43xt3pebupiz26noaypbobqj4gw2z5njubsnsy3la7enx2j3exz.py +307 -0
  10. SpecForge-ext/cache/compiled_kernels/46/7bb099c5cd896693d7710d6358df16b56a62f466c1207186e1ec6fa6aaeb5653.best_config +1 -0
  11. SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py +52 -0
  12. SpecForge-ext/cache/compiled_kernels/5d/c5dxklofifxzswxjhdjvko4ncyrk6vkfrbohhy3eg5kffm63zqjg.py +62 -0
  13. SpecForge-ext/cache/compiled_kernels/5i/c5ijpx5gd3tgnruo2ufacghy5ivgjwoj6s4fhr7c2advvuhujqou.py +334 -0
  14. SpecForge-ext/cache/compiled_kernels/5n/c5n5tizmmgs4cmiupzpopubn6t7eviwt42e3csvt472h63vjwmbu.py +835 -0
  15. SpecForge-ext/cache/compiled_kernels/6a/c6akpy3glququp6suktd5kfns5jol46fmjfl5brlavs7c4zodqhi.py +354 -0
  16. SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py +25 -0
  17. SpecForge-ext/cache/compiled_kernels/73/c73fwyrbq77x2xei6vgy66r34rmoo6pilk6kf2iqehy6oksjmbwj.py +693 -0
  18. SpecForge-ext/cache/compiled_kernels/bj/cbj72m23cmcn2yjoxrp4vabc2f76gw727jcpbi4y5oidokqenki5.py +24 -0
  19. SpecForge-ext/cache/compiled_kernels/ce/ccez434a7hzyympuosxgkqmu5zncaqowipase2enwlehm3k7igny.py +72 -0
  20. SpecForge-ext/cache/compiled_kernels/cm/c07273381756209821a51449b0970c31551d44d464333fbb852c9fc655362c46.best_config +1 -0
  21. SpecForge-ext/cache/compiled_kernels/cr/ccr5s7nffy4cqd7a3lcq3cnv2prruzwzc7chchf776jguuqqh5bc.py +66 -0
  22. SpecForge-ext/cache/compiled_kernels/d4/cd4qgy6v3vbg74qytdbsmdpamjzb6kuwcsiu7yfpml4f7zxhf4j3.py +164 -0
  23. SpecForge-ext/cache/compiled_kernels/e4/9676ab333d44c7c3eec122b806e4fc2028468bcd55a94115e7272e322515d58c.best_config +1 -0
  24. SpecForge-ext/cache/compiled_kernels/e6/ce65awkeaxcxjqfa27pcogsy3sjyxwzxjt3w2rte76m7izgybp2s.py +72 -0
  25. SpecForge-ext/cache/compiled_kernels/ea/ceahttlkg35qey3ao6gw65rzzv3bop5xwrthhogt6nyvyw3rece5.py +835 -0
  26. SpecForge-ext/cache/compiled_kernels/eg/32c642d9b91cf1b4fe91f745e14ee86eccc7f5783759ca318976f5af47c474c9.best_config +1 -0
  27. SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py +66 -0
  28. SpecForge-ext/cache/compiled_kernels/eg/cegtc7d6fywtdtf2rerfuwwwn7fajohhh2ltqlvjjvdguxabbva4.py +416 -0
  29. SpecForge-ext/cache/compiled_kernels/el/celb4xosanyf3m2sx6v3t54w4bgkx65m4lb2newu7nggikw6jbxj.py +52 -0
  30. SpecForge-ext/cache/compiled_kernels/et/cet6lrlwcthdi3by3ttnab2z245l4q55x7tvdilkic6xqjfjlixg.py +62 -0
  31. SpecForge-ext/cache/compiled_kernels/ey/cey3ar6s7f2t62buescu5cctxdhf6hmbv3ps5d3tmh235oaj3fj6.py +56 -0
  32. SpecForge-ext/cache/compiled_kernels/ey/ceyifglcwq5k7zog6faauufd7zk5fsacgjqk43m6vpya73dy3l62.py +543 -0
  33. SpecForge-ext/cache/compiled_kernels/ey/ceyzf3pcewvjtqjk6jiokovxh2sqktcak7dttp7wu3pugjxaoweu.py +835 -0
  34. SpecForge-ext/cache/compiled_kernels/ey/f8e6f482f3185b2937177b6d0b6caa60104c3cdb0966b9b98cfda24132197a8c.best_config +1 -0
  35. SpecForge-ext/cache/compiled_kernels/fg/cfg7sytfzjcof3mvqa6lexwoxlaj3zogf2jn2jbgerew6ytuhqkm.py +37 -0
  36. SpecForge-ext/cache/compiled_kernels/fg/cfgdj37atk5pvqz7oags4dv3jc65exjssmxxu3c4srgtfjnh7kgw.py +552 -0
  37. SpecForge-ext/cache/compiled_kernels/fg/cfgilsqr4dj7cpcripi7zlobhu3rqxlfddiwwrzuy5xlumnjw5lh.py +37 -0
  38. SpecForge-ext/cache/compiled_kernels/fo/cfooe7ht55q5jhejzd3zyb3g5v64cvxjohkxeadllgnjxgiwo52v.py +26 -0
  39. SpecForge-ext/cache/compiled_kernels/ft/cftsee2mvtzxgy2wgchwunv4g4rgysco4n3gsokqlal6zoqbmnub.py +303 -0
  40. SpecForge-ext/cache/compiled_kernels/g3/cg3kczutozttzr55b4vjq62nto7vv2qnqb553mhae4gtgepz7vkj.py +89 -0
  41. SpecForge-ext/cache/compiled_kernels/hq/chqc3is7lze3bdohf7qrowyfetyhjquhgfsobrnoq7hbrmp6ohdx.py +334 -0
  42. SpecForge-ext/cache/compiled_kernels/hq/chqstdcrwlggtj2cbkjjgtxib54f5qfcipeqs3k27hifudgguv7t.py +835 -0
  43. SpecForge-ext/cache/compiled_kernels/hv/chvj5h3adlnuxifatrhlirixthstwv5pzbxvuapjby5cz2npck63.py +99 -0
  44. SpecForge-ext/cache/compiled_kernels/ks/36cfdc5c4318d8e35940f3471fa9a8cde8092c3294a90679819920b4db6ea3bb.best_config +1 -0
  45. SpecForge-ext/cache/compiled_kernels/ks/cksdatp7sjl5kfr5pxvwrbjelhvz35c35rvym5wgbvhrovwd5isa.py +62 -0
  46. SpecForge-ext/cache/compiled_kernels/ks/ckske6cm4vgoewu6hpzmhdk7yxnddtnqlrbts7nwodsrty3grim2.py +25 -0
  47. SpecForge-ext/cache/compiled_kernels/kx/ckxgrh6l45wgzd3gv6uy3i3z4hrfyct6es6sh2fdnsi6q4hicyjs.py +168 -0
  48. SpecForge-ext/cache/compiled_kernels/kx/ckxtdzhg3azhdxeooy2uushwzka4sz2hzjpq5dulk2g2jjweqr6b.py +552 -0
  49. SpecForge-ext/cache/compiled_kernels/l4/858a6c2e50b765fa4386efe0007977eb588741281d2c492d383f481ceaa46b11.best_config +1 -0
  50. SpecForge-ext/cache/compiled_kernels/l4/cl45ilp34erze7maypgnzjiaafh3lmzk67erw2irtjg7fhwhyggv.py +835 -0
SpecForge-ext/cache/compiled_kernels/27/178fbe80a6655cd928415af01c18f906aac88a91f6ad870b6e019e505e41d8d6.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 140, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"}
SpecForge-ext/cache/compiled_kernels/27/9e94f85bd7b5310932ad7debf393fd211c0f4a83a0599bb42899fad47199226a.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"}
SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/27/c27s4qoyzyvf54snkgtay3lqlnoj3bgphotvv5xwczxe6bqovure.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 524288, 'r0_': 128},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 128
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = ((xindex // ks0) % 32)
29
+ x2 = xindex // ks1
30
+ x5 = triton_helpers.div_floor_integer(xindex, ks0)
31
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
32
+ x4 = xindex
33
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_3 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
40
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
41
+ tmp2 = tmp0 * tmp1
42
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
43
+ tmp5 = _tmp4 + tmp3
44
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
45
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
46
+ tmp6 = tmp4.to(tl.float32)
47
+ tmp7 = 0.0
48
+ tmp8 = tmp6 - tmp7
49
+ tl.store(out_ptr1 + (x4), tmp8, xmask)
SpecForge-ext/cache/compiled_kernels/3j/319ae573ab7247866e4ff0749ebbc205378f8b72492d64376005ae12fdcc85cb.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"}
SpecForge-ext/cache/compiled_kernels/3j/c3j47dekusw3y4mohtk5v36cc6fso3wdtqn5oqjwew3yy3exjo76.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 64, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 16
20
+ R0_BLOCK: tl.constexpr = 16
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x0 = (xindex % ks0)
33
+ x1 = xindex // ks0
34
+ x3 = xindex
35
+ tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0)
36
+ tmp1 = r0_2
37
+ tmp2 = tmp1.to(tl.int16)
38
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
40
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
41
+ tmp7 = tmp0.to(tl.int64)
42
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
43
+ tmp10 = tl.where(xmask, tmp8, 0)
44
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
45
+ tmp12 = tmp6.to(tl.int64)
46
+ tmp13 = tmp12.to(tl.int32)
47
+ tmp14 = tmp11.to(tl.int32)
48
+ tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask)
49
+ tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask)
SpecForge-ext/cache/compiled_kernels/3r/c3rfwo25yzbkbl5er7svhpxhxxjdbre5zoxeq5wwcwlsvq2puinx.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/df/cdfb6cgenzsju5cqvy4244xh4xidniyeznvkubvdg2mg6d5oc6xt.py
38
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cos => squeeze_1
41
+ # cos_1 => unsqueeze
42
+ # getitem => index
43
+ # getitem_1 => index_1
44
+ # sin => squeeze_3
45
+ # sin_1 => unsqueeze_1
46
+ # squeeze => squeeze
47
+ # squeeze_2 => squeeze_2
48
+ # Graph fragment:
49
+ # %tangents_2 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7" = PlaceHolder[target=tangents_2]
50
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8]
51
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6]
52
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4]
53
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
54
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
55
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
56
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
57
+ # %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {})
58
+ # %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {})
59
+ # %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {})
60
+ # %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {})
61
+ # %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:7, pin_memory: False})
62
+ # %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {})
63
+ # %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {})
64
+ # %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
65
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
66
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
67
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
68
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
69
+ # %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {})
70
+ # %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {})
71
+ # return %add_101
72
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', '''
73
+ import triton
74
+ import triton.language as tl
75
+
76
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
77
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
78
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
79
+ triton_helpers.set_driver_to_gpu()
80
+
81
+ @triton_heuristics.pointwise(
82
+ size_hints={'x': 4194304},
83
+ filename=__file__,
84
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
85
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
86
+ min_elem_per_thread=0
87
+ )
88
+ @triton.jit
89
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
90
+ xoffset = tl.program_id(0) * XBLOCK
91
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
92
+ xmask = xindex < xnumel
93
+ x0 = (xindex % ks0)
94
+ x3 = xindex
95
+ x1 = ((xindex // ks0) % ks1)
96
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
97
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
98
+ tmp0 = x0
99
+ tmp1 = ks0 // 2
100
+ tmp2 = tmp0 >= tmp1
101
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
102
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
103
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
104
+ tmp6 = tmp4 + tmp5
105
+ tmp7 = tmp4 < 0
106
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
107
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
108
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
109
+ tmp11 = tmp3 * tmp10
110
+ tmp12 = -tmp11
111
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
112
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
113
+ tmp15 = 0.0
114
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
115
+ tmp17 = tmp0 < tmp1
116
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
117
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
118
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
119
+ tmp21 = tmp19 + tmp20
120
+ tmp22 = tmp19 < 0
121
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
122
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
123
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
124
+ tmp26 = tmp18 * tmp25
125
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
126
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
127
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
128
+ tmp30 = tmp16 + tmp29
129
+ tmp33 = ks3
130
+ tmp34 = tmp32 + tmp33
131
+ tmp35 = tmp32 < 0
132
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
133
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
134
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
135
+ tmp39 = tmp31 * tmp38
136
+ tmp40 = tmp30 + tmp39
137
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
138
+ ''', device_str='cuda')
139
+
140
+
141
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gg/cgggk6pegregqt4lolln3yxfp6wzahy6vf2ocae3vbpohfif7mtz.py
142
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
143
+ # Source node to ATen node mapping:
144
+ # cos => squeeze_1
145
+ # cos_1 => unsqueeze
146
+ # getitem => index
147
+ # getitem_1 => index_1
148
+ # sin => squeeze_3
149
+ # sin_1 => unsqueeze_1
150
+ # squeeze => squeeze
151
+ # squeeze_2 => squeeze_2
152
+ # Graph fragment:
153
+ # %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7" = PlaceHolder[target=tangents_1]
154
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8]
155
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4]
157
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
158
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
159
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
160
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
161
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
162
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
163
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
164
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
165
+ # %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {})
166
+ # %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {})
167
+ # %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {})
168
+ # %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {})
169
+ # %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:7, pin_memory: False})
170
+ # %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {})
171
+ # %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {})
172
+ # %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {})
173
+ # %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {})
174
+ # %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {})
175
+ # return %add_107
176
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', '''
177
+ import triton
178
+ import triton.language as tl
179
+
180
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
181
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
182
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
183
+ triton_helpers.set_driver_to_gpu()
184
+
185
+ @triton_heuristics.pointwise(
186
+ size_hints={'x': 16777216},
187
+ filename=__file__,
188
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
189
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
190
+ min_elem_per_thread=0
191
+ )
192
+ @triton.jit
193
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
194
+ xoffset = tl.program_id(0) * XBLOCK
195
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
196
+ xmask = xindex < xnumel
197
+ x0 = (xindex % ks0)
198
+ x3 = xindex
199
+ x1 = ((xindex // ks0) % ks1)
200
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
201
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
202
+ tmp0 = x0
203
+ tmp1 = ks0 // 2
204
+ tmp2 = tmp0 >= tmp1
205
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
206
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
207
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
208
+ tmp6 = tmp4 + tmp5
209
+ tmp7 = tmp4 < 0
210
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
211
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
212
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
213
+ tmp11 = tmp3 * tmp10
214
+ tmp12 = -tmp11
215
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
216
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
217
+ tmp15 = 0.0
218
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
219
+ tmp17 = tmp0 < tmp1
220
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
222
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
223
+ tmp21 = tmp19 + tmp20
224
+ tmp22 = tmp19 < 0
225
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
226
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
227
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
228
+ tmp26 = tmp18 * tmp25
229
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
230
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
231
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
232
+ tmp30 = tmp16 + tmp29
233
+ tmp33 = ks3
234
+ tmp34 = tmp32 + tmp33
235
+ tmp35 = tmp32 < 0
236
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
237
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
238
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
239
+ tmp39 = tmp31 * tmp38
240
+ tmp40 = tmp30 + tmp39
241
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
242
+ ''', device_str='cuda')
243
+
244
+
245
+ async_compile.wait(globals())
246
+ del async_compile
247
+
248
+ class Runner:
249
+ def __init__(self, partitions):
250
+ self.partitions = partitions
251
+
252
+ def recursively_apply_fns(self, fns):
253
+ new_callables = []
254
+ for fn, c in zip(fns, self.partitions):
255
+ new_callables.append(fn(c))
256
+ self.partitions = new_callables
257
+
258
+ def call(self, args):
259
+ primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args
260
+ args.clear()
261
+ s24 = primals_2
262
+ s9 = primals_7
263
+ s48 = primals_10
264
+ s34 = primals_11
265
+ s25 = primals_13
266
+ s92 = primals_1
267
+ s96 = primals_3
268
+ s79 = primals_5
269
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
270
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
271
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
272
+ assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1))
273
+ assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1))
274
+ with torch.cuda._DeviceGuard(7):
275
+ torch.cuda.set_device(7)
276
+ buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16)
277
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
278
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9
279
+ stream7 = get_raw_stream(7)
280
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream7)
281
+ del tangents_2
282
+ buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16)
283
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
284
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9
285
+ stream7 = get_raw_stream(7)
286
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream7)
287
+ del primals_4
288
+ del primals_6
289
+ del primals_8
290
+ del tangents_1
291
+ return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, )
292
+
293
+ runner = Runner(partitions=[])
294
+ call = runner.call
295
+ recursively_apply_fns = runner.recursively_apply_fns
296
+
297
+
298
+ def benchmark_compiled_module(times=10, repeat=10):
299
+ from torch._dynamo.testing import rand_strided
300
+ from torch._inductor.utils import print_performance
301
+ primals_2 = 128
302
+ primals_7 = 2048
303
+ primals_10 = 2
304
+ primals_11 = 32
305
+ primals_13 = 8
306
+ primals_1 = 2048
307
+ primals_3 = 5245440
308
+ primals_5 = 2048
309
+ floordiv = 64
310
+ add_96 = 64
311
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16)
312
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16)
313
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:7', dtype=torch.int64)
314
+ tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16)
315
+ tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16)
316
+ fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2])
317
+ return print_performance(fn, times=times, repeat=repeat)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ from torch._inductor.wrapper_benchmark import compiled_module_main
322
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/3z/c3zdyaemmekwelpoed5jduslt3o4gp6avp6it2wx3udu2z3kxz65.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rv/crvxbztkr372fgdn7bgrud22s3wmd2isidwo4ek4hldn56tuv2dj.py
38
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cat => cat
41
+ # cos => squeeze_1
42
+ # cos_1 => unsqueeze
43
+ # getitem => index
44
+ # getitem_1 => index_1
45
+ # mul => mul_24
46
+ # mul_1 => mul_45
47
+ # neg => neg
48
+ # q_embed => add_54
49
+ # sin => squeeze_3
50
+ # sin_1 => unsqueeze_1
51
+ # squeeze => squeeze
52
+ # squeeze_2 => squeeze_2
53
+ # x1 => slice_1
54
+ # x2 => slice_2
55
+ # Graph fragment:
56
+ # %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1" = PlaceHolder[target=primals_12]
57
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8]
58
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4]
59
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6]
60
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
61
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
62
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
63
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
64
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
65
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
66
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
67
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
68
+ # %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {})
69
+ # %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {})
70
+ # %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {})
71
+ # %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
72
+ # %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
73
+ # %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {})
74
+ # %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {})
75
+ # return %add_54
76
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', '''
77
+ import triton
78
+ import triton.language as tl
79
+
80
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
81
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
82
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
83
+ triton_helpers.set_driver_to_gpu()
84
+
85
+ @triton_heuristics.pointwise(
86
+ size_hints={'x': 16777216},
87
+ filename=__file__,
88
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
89
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
90
+ min_elem_per_thread=0
91
+ )
92
+ @triton.jit
93
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
94
+ xoffset = tl.program_id(0) * XBLOCK
95
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
96
+ xmask = xindex < xnumel
97
+ x4 = xindex
98
+ x2 = ((xindex // ks0) % ks1)
99
+ x0 = (xindex % ks3)
100
+ x5 = xindex // ks3
101
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
102
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
103
+ tmp2 = ks2
104
+ tmp3 = tmp1 + tmp2
105
+ tmp4 = tmp1 < 0
106
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
107
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
108
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
109
+ tmp8 = tmp0 * tmp7
110
+ tmp9 = x0
111
+ tmp10 = tl.full([1], 0, tl.int64)
112
+ tmp11 = tmp9 >= tmp10
113
+ tmp12 = ks3 + (-1)*(ks3 // 2)
114
+ tmp13 = tmp9 < tmp12
115
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
116
+ tmp15 = -tmp14
117
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
118
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
119
+ tmp18 = tmp9 >= tmp12
120
+ tmp19 = ks3
121
+ tmp20 = tmp9 < tmp19
122
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
123
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
124
+ tmp23 = ks4
125
+ tmp24 = tmp1 + tmp23
126
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
127
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
128
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
129
+ tmp28 = tmp22 * tmp27
130
+ tmp29 = tmp8 + tmp28
131
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
132
+ ''', device_str='cuda')
133
+
134
+
135
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kt/ckt6ylj5dotkksawawvin5yyeytmo5tcvmqpulhfstvqh3aecfft.py
136
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
137
+ # Source node to ATen node mapping:
138
+ # cat_1 => cat_1
139
+ # cos => squeeze_1
140
+ # cos_1 => unsqueeze
141
+ # getitem => index
142
+ # getitem_1 => index_1
143
+ # k_embed => add_90
144
+ # mul_2 => mul_54
145
+ # mul_3 => mul_75
146
+ # neg_1 => neg_1
147
+ # sin => squeeze_3
148
+ # sin_1 => unsqueeze_1
149
+ # squeeze => squeeze
150
+ # squeeze_2 => squeeze_2
151
+ # x1_1 => slice_3
152
+ # x2_1 => slice_4
153
+ # Graph fragment:
154
+ # %primals_14 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:1" = PlaceHolder[target=primals_14]
155
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4]
157
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6]
158
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
159
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
160
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
161
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
162
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
163
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
164
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
165
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
166
+ # %mul_54 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_14, %unsqueeze), kwargs = {})
167
+ # %slice_3 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, 0, %floordiv), kwargs = {})
168
+ # %slice_4 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, %floordiv, 9223372036854775807), kwargs = {})
169
+ # %neg_1 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s25*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
170
+ # %cat_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
171
+ # %mul_75 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {})
172
+ # %add_90 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {})
173
+ # return %add_90
174
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', '''
175
+ import triton
176
+ import triton.language as tl
177
+
178
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
179
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
180
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
181
+ triton_helpers.set_driver_to_gpu()
182
+
183
+ @triton_heuristics.pointwise(
184
+ size_hints={'x': 4194304},
185
+ filename=__file__,
186
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
187
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
188
+ min_elem_per_thread=0
189
+ )
190
+ @triton.jit
191
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
192
+ xoffset = tl.program_id(0) * XBLOCK
193
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
194
+ xmask = xindex < xnumel
195
+ x4 = xindex
196
+ x2 = ((xindex // ks0) % ks1)
197
+ x0 = (xindex % ks3)
198
+ x5 = xindex // ks3
199
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
200
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
201
+ tmp2 = ks2
202
+ tmp3 = tmp1 + tmp2
203
+ tmp4 = tmp1 < 0
204
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
205
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
206
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
207
+ tmp8 = tmp0 * tmp7
208
+ tmp9 = x0
209
+ tmp10 = tl.full([1], 0, tl.int64)
210
+ tmp11 = tmp9 >= tmp10
211
+ tmp12 = ks3 + (-1)*(ks3 // 2)
212
+ tmp13 = tmp9 < tmp12
213
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
214
+ tmp15 = -tmp14
215
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
216
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
217
+ tmp18 = tmp9 >= tmp12
218
+ tmp19 = ks3
219
+ tmp20 = tmp9 < tmp19
220
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
222
+ tmp23 = ks4
223
+ tmp24 = tmp1 + tmp23
224
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
225
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
226
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
227
+ tmp28 = tmp22 * tmp27
228
+ tmp29 = tmp8 + tmp28
229
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
230
+ ''', device_str='cuda')
231
+
232
+
233
+ async_compile.wait(globals())
234
+ del async_compile
235
+
236
+ class Runner:
237
+ def __init__(self, partitions):
238
+ self.partitions = partitions
239
+
240
+ def recursively_apply_fns(self, fns):
241
+ new_callables = []
242
+ for fn, c in zip(fns, self.partitions):
243
+ new_callables.append(fn(c))
244
+ self.partitions = new_callables
245
+
246
+ def call(self, args):
247
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14 = args
248
+ args.clear()
249
+ s92 = primals_1
250
+ s24 = primals_2
251
+ s96 = primals_3
252
+ s79 = primals_5
253
+ s9 = primals_7
254
+ s38 = primals_9
255
+ s48 = primals_10
256
+ s34 = primals_11
257
+ s25 = primals_13
258
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
259
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
260
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
261
+ assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1))
262
+ assert_size_stride(primals_14, (s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1))
263
+ with torch.cuda._DeviceGuard(1):
264
+ torch.cuda.set_device(1)
265
+ ps0 = s24*s34
266
+ buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16)
267
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
268
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9
269
+ stream1 = get_raw_stream(1)
270
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream1)
271
+ del primals_12
272
+ ps1 = s24*s25
273
+ buf1 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1), torch.bfloat16)
274
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
275
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s25*s48*s9
276
+ stream1 = get_raw_stream(1)
277
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_14, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream1)
278
+ del primals_14
279
+ return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s25, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), )
280
+
281
+ runner = Runner(partitions=[])
282
+ call = runner.call
283
+ recursively_apply_fns = runner.recursively_apply_fns
284
+
285
+
286
+ def benchmark_compiled_module(times=10, repeat=10):
287
+ from torch._dynamo.testing import rand_strided
288
+ from torch._inductor.utils import print_performance
289
+ primals_1 = 2048
290
+ primals_2 = 128
291
+ primals_3 = 5245440
292
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16)
293
+ primals_5 = 2048
294
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16)
295
+ primals_7 = 2048
296
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:1', dtype=torch.int64)
297
+ primals_9 = 1
298
+ primals_10 = 2
299
+ primals_11 = 32
300
+ primals_12 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
301
+ primals_13 = 8
302
+ primals_14 = rand_strided((2, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
303
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14])
304
+ return print_performance(fn, times=times, repeat=repeat)
305
+
306
+
307
+ if __name__ == "__main__":
308
+ from torch._inductor.wrapper_benchmark import compiled_module_main
309
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/43/c43xt3pebupiz26noaypbobqj4gw2z5njubsnsy3la7enx2j3exz.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py
38
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cat => cat
41
+ # cos => squeeze_1
42
+ # cos_1 => unsqueeze
43
+ # getitem => index
44
+ # getitem_1 => index_1
45
+ # mul => mul_24
46
+ # mul_1 => mul_45
47
+ # neg => neg
48
+ # q_embed => add_54
49
+ # sin => squeeze_3
50
+ # sin_1 => unsqueeze_1
51
+ # squeeze => squeeze
52
+ # squeeze_2 => squeeze_2
53
+ # x1 => slice_1
54
+ # x2 => slice_2
55
+ # Graph fragment:
56
+ # %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:4" = PlaceHolder[target=primals_12]
57
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:4" = PlaceHolder[target=primals_8]
58
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_4]
59
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_6]
60
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
61
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
62
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
63
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
64
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
65
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
66
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
67
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
68
+ # %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {})
69
+ # %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {})
70
+ # %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {})
71
+ # %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
72
+ # %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
73
+ # %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {})
74
+ # %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {})
75
+ # return %add_54
76
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', '''
77
+ import triton
78
+ import triton.language as tl
79
+
80
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
81
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
82
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
83
+ triton_helpers.set_driver_to_gpu()
84
+
85
+ @triton_heuristics.pointwise(
86
+ size_hints={'x': 67108864},
87
+ filename=__file__,
88
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
89
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
90
+ min_elem_per_thread=0
91
+ )
92
+ @triton.jit
93
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
94
+ xoffset = tl.program_id(0) * XBLOCK
95
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
96
+ xmask = xindex < xnumel
97
+ x4 = xindex
98
+ x2 = ((xindex // ks0) % ks1)
99
+ x0 = (xindex % ks3)
100
+ x5 = xindex // ks3
101
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
102
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
103
+ tmp2 = ks2
104
+ tmp3 = tmp1 + tmp2
105
+ tmp4 = tmp1 < 0
106
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
107
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
108
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
109
+ tmp8 = tmp0 * tmp7
110
+ tmp9 = x0
111
+ tmp10 = tl.full([1], 0, tl.int64)
112
+ tmp11 = tmp9 >= tmp10
113
+ tmp12 = ks3 + (-1)*(ks3 // 2)
114
+ tmp13 = tmp9 < tmp12
115
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
116
+ tmp15 = -tmp14
117
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
118
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
119
+ tmp18 = tmp9 >= tmp12
120
+ tmp19 = ks3
121
+ tmp20 = tmp9 < tmp19
122
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
123
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
124
+ tmp23 = ks4
125
+ tmp24 = tmp1 + tmp23
126
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
127
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
128
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
129
+ tmp28 = tmp22 * tmp27
130
+ tmp29 = tmp8 + tmp28
131
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
132
+ ''', device_str='cuda')
133
+
134
+
135
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xg/cxgourummzwsux6r2gxe7ifvqpdhpgvgbs36tkitfwpr24b4gcvt.py
136
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
137
+ # Source node to ATen node mapping:
138
+ # cat_1 => cat_1
139
+ # cos => squeeze_1
140
+ # cos_1 => unsqueeze
141
+ # getitem => index
142
+ # getitem_1 => index_1
143
+ # k_embed => add_90
144
+ # mul_2 => mul_54
145
+ # mul_3 => mul_75
146
+ # neg_1 => neg_1
147
+ # sin => squeeze_3
148
+ # sin_1 => unsqueeze_1
149
+ # squeeze => squeeze
150
+ # squeeze_2 => squeeze_2
151
+ # x1_1 => slice_3
152
+ # x2_1 => slice_4
153
+ # Graph fragment:
154
+ # %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:4" = PlaceHolder[target=primals_13]
155
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:4" = PlaceHolder[target=primals_8]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_4]
157
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_6]
158
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
159
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
160
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
161
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
162
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
163
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
164
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
165
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
166
+ # %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {})
167
+ # %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {})
168
+ # %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {})
169
+ # %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
170
+ # %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
171
+ # %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {})
172
+ # %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {})
173
+ # return %add_90
174
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', '''
175
+ import triton
176
+ import triton.language as tl
177
+
178
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
179
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
180
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
181
+ triton_helpers.set_driver_to_gpu()
182
+
183
+ @triton_heuristics.pointwise(
184
+ size_hints={'x': 16777216},
185
+ filename=__file__,
186
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
187
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
188
+ min_elem_per_thread=0
189
+ )
190
+ @triton.jit
191
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
192
+ xoffset = tl.program_id(0) * XBLOCK
193
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
194
+ xmask = xindex < xnumel
195
+ x4 = xindex
196
+ x2 = ((xindex // ks0) % ks1)
197
+ x0 = (xindex % ks3)
198
+ x5 = xindex // ks3
199
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
200
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
201
+ tmp2 = ks2
202
+ tmp3 = tmp1 + tmp2
203
+ tmp4 = tmp1 < 0
204
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
205
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
206
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
207
+ tmp8 = tmp0 * tmp7
208
+ tmp9 = x0
209
+ tmp10 = tl.full([1], 0, tl.int64)
210
+ tmp11 = tmp9 >= tmp10
211
+ tmp12 = ks3 + (-1)*(ks3 // 2)
212
+ tmp13 = tmp9 < tmp12
213
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
214
+ tmp15 = -tmp14
215
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
216
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
217
+ tmp18 = tmp9 >= tmp12
218
+ tmp19 = ks3
219
+ tmp20 = tmp9 < tmp19
220
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
222
+ tmp23 = ks4
223
+ tmp24 = tmp1 + tmp23
224
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
225
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
226
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
227
+ tmp28 = tmp22 * tmp27
228
+ tmp29 = tmp8 + tmp28
229
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
230
+ ''', device_str='cuda')
231
+
232
+
233
+ async_compile.wait(globals())
234
+ del async_compile
235
+
236
+ class Runner:
237
+ def __init__(self, partitions):
238
+ self.partitions = partitions
239
+
240
+ def recursively_apply_fns(self, fns):
241
+ new_callables = []
242
+ for fn, c in zip(fns, self.partitions):
243
+ new_callables.append(fn(c))
244
+ self.partitions = new_callables
245
+
246
+ def call(self, args):
247
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args
248
+ args.clear()
249
+ s92 = primals_1
250
+ s24 = primals_2
251
+ s96 = primals_3
252
+ s79 = primals_5
253
+ s9 = primals_7
254
+ s38 = primals_9
255
+ s48 = primals_10
256
+ s34 = primals_11
257
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
258
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
259
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
260
+ assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1))
261
+ assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1))
262
+ with torch.cuda._DeviceGuard(4):
263
+ torch.cuda.set_device(4)
264
+ ps0 = s24*s34
265
+ buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16)
266
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
267
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9
268
+ stream4 = get_raw_stream(4)
269
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream4)
270
+ del primals_12
271
+ ps1 = s24*s48
272
+ buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16)
273
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
274
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48
275
+ stream4 = get_raw_stream(4)
276
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream4)
277
+ del primals_13
278
+ return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), )
279
+
280
+ runner = Runner(partitions=[])
281
+ call = runner.call
282
+ recursively_apply_fns = runner.recursively_apply_fns
283
+
284
+
285
+ def benchmark_compiled_module(times=10, repeat=10):
286
+ from torch._dynamo.testing import rand_strided
287
+ from torch._inductor.utils import print_performance
288
+ primals_1 = 2048
289
+ primals_2 = 128
290
+ primals_3 = 5245440
291
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:4', dtype=torch.bfloat16)
292
+ primals_5 = 2048
293
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:4', dtype=torch.bfloat16)
294
+ primals_7 = 2048
295
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:4', dtype=torch.int64)
296
+ primals_9 = 1
297
+ primals_10 = 8
298
+ primals_11 = 32
299
+ primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
300
+ primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16)
301
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13])
302
+ return print_performance(fn, times=times, repeat=repeat)
303
+
304
+
305
+ if __name__ == "__main__":
306
+ from torch._inductor.wrapper_benchmark import compiled_module_main
307
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/46/7bb099c5cd896693d7710d6358df16b56a62f466c1207186e1ec6fa6aaeb5653.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 63, "triton_cache_hash": "C3FCZCDEMCLSFODWXLEH5MRAQRWLOTRP4SAQURVAE7BPHZSTV2WQ"}
SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 131072, 'r0_': 128},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x1 = xindex // ks0
27
+ x0 = (xindex % ks0)
28
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
29
+ x3 = xindex
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_2 = r0_index
36
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
37
+ tmp1 = ks1*ks2
38
+ tmp2 = tmp0 < tmp1
39
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
40
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp5 = tmp4.to(tl.float32)
42
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
43
+ tmp7 = tmp5 * tmp6
44
+ tmp8 = tmp7.to(tl.float32)
45
+ tmp9 = tmp3 * tmp8
46
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
47
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
48
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
49
+ tmp14 = _tmp13 + tmp12
50
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
51
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
52
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
SpecForge-ext/cache/compiled_kernels/5d/c5dxklofifxzswxjhdjvko4ncyrk6vkfrbohhy3eg5kffm63zqjg.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 1, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 1
20
+ r0_numel = 4096
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
29
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
30
+ r0_index = r0_offset + r0_base
31
+ r0_mask = r0_index < r0_numel
32
+ roffset = r0_offset
33
+ rindex = r0_index
34
+ r0_0 = r0_index
35
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
36
+ tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
37
+ tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
38
+ tmp2 = tmp0 == tmp1
39
+ tmp3 = tmp2.to(tl.int64)
40
+ tmp5 = tmp3 * tmp4
41
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
42
+ tmp8 = _tmp7 + tmp6
43
+ _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
44
+ tmp7 = tl.sum(_tmp7, 1)[:, None]
45
+ _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
46
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
47
+ r0_index = r0_offset + r0_base
48
+ r0_mask = r0_index < r0_numel
49
+ roffset = r0_offset
50
+ rindex = r0_index
51
+ r0_0 = r0_index
52
+ tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
53
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
54
+ tmp12 = _tmp11 + tmp10
55
+ _tmp11 = tl.where(r0_mask, tmp12, _tmp11)
56
+ tmp11 = tl.sum(_tmp11, 1)[:, None]
57
+ tmp13 = tmp7.to(tl.float32)
58
+ tmp14 = tmp11.to(tl.float32)
59
+ tmp15 = 1e-06
60
+ tmp16 = triton_helpers.maximum(tmp14, tmp15)
61
+ tmp17 = (tmp13 / tmp16)
62
+ tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None)
SpecForge-ext/cache/compiled_kernels/5i/c5ijpx5gd3tgnruo2ufacghy5ivgjwoj6s4fhr7c2advvuhujqou.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['2_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py
38
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
39
+ # Source node to ATen node mapping:
40
+ # hidden_states => convert_element_type
41
+ # hidden_states_1 => mul_16
42
+ # to_1 => convert_element_type_1
43
+ # Graph fragment:
44
+ # %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=tangents_1]
45
+ # %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=primals_4]
46
+ # %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1" = PlaceHolder[target=rsqrt]
47
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
48
+ # %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
49
+ # %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
50
+ # %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
51
+ # %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {})
52
+ # return %buf0
53
+ triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+ triton_helpers.set_driver_to_gpu()
61
+
62
+ @triton_heuristics.reduction(
63
+ size_hints={'x': 131072, 'r0_': 128},
64
+ reduction_hint=ReductionHint.OUTER,
65
+ filename=__file__,
66
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
67
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
68
+ )
69
+ @triton.jit
70
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x1 = xindex // ks0
79
+ x0 = (xindex % ks0)
80
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
81
+ x3 = xindex
82
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
83
+ r0_index = r0_offset + r0_base
84
+ r0_mask = r0_index < r0_numel
85
+ roffset = r0_offset
86
+ rindex = r0_index
87
+ r0_2 = r0_index
88
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
89
+ tmp1 = ks1*ks2
90
+ tmp2 = tmp0 < tmp1
91
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
92
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
93
+ tmp5 = tmp4.to(tl.float32)
94
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
95
+ tmp7 = tmp5 * tmp6
96
+ tmp8 = tmp7.to(tl.float32)
97
+ tmp9 = tmp3 * tmp8
98
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
99
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
100
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
101
+ tmp14 = _tmp13 + tmp12
102
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
103
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
104
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
105
+ ''', device_str='cuda')
106
+
107
+
108
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ug/cug6unv7ylx7cgtwxj6q5dppff2io2k4qf3fhtoe6a2mcfi5dzu5.py
109
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
110
+ # Source node to ATen node mapping:
111
+ # hidden_states => convert_element_type
112
+ # hidden_states_1 => mul_16
113
+ # to_1 => convert_element_type_1
114
+ # Graph fragment:
115
+ # %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:1" = PlaceHolder[target=buf0]
116
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
117
+ # %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
118
+ # %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
119
+ # %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
120
+ # %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {})
121
+ # return %sum_1
122
+ triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', '''
123
+ import triton
124
+ import triton.language as tl
125
+
126
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
127
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
128
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
129
+ triton_helpers.set_driver_to_gpu()
130
+
131
+ @triton_heuristics.persistent_reduction(
132
+ size_hints={'x': 4096, 'r0_': 32},
133
+ reduction_hint=ReductionHint.OUTER,
134
+ filename=__file__,
135
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
137
+ )
138
+ @triton.jit
139
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
140
+ r0_numel = 32
141
+ R0_BLOCK: tl.constexpr = 32
142
+ rnumel = r0_numel
143
+ RBLOCK: tl.constexpr = R0_BLOCK
144
+ xoffset = tl.program_id(0) * XBLOCK
145
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
146
+ xmask = xindex < xnumel
147
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
148
+ r0_offset = 0
149
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
150
+ roffset = r0_offset
151
+ rindex = r0_index
152
+ r0_1 = r0_index
153
+ x0 = xindex
154
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
155
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
156
+ tmp3 = tl.where(xmask, tmp1, 0)
157
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
158
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
159
+ ''', device_str='cuda')
160
+
161
+
162
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sg/csg2r6gcuw5453tnkx7v65zysasesetlrx733ekbslnhgjntjrkm.py
163
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add]
164
+ # Source node to ATen node mapping:
165
+ # hidden_states => convert_element_type
166
+ # Graph fragment:
167
+ # %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=tangents_1]
168
+ # %primals_7 : Tensor "bf16[s33][1]cuda:1" = PlaceHolder[target=primals_7]
169
+ # %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=primals_4]
170
+ # %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1" = PlaceHolder[target=rsqrt]
171
+ # %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:1" = PlaceHolder[target=sum_2]
172
+ # %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {})
173
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
174
+ # %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {})
175
+ # %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {})
176
+ # %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {})
177
+ # %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), kwargs = {})
178
+ # %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {})
179
+ # %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {})
180
+ # %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {})
181
+ # %expand : Tensor "f32[s47, s87, s33][s87, 1, 0]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {})
182
+ # %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {})
183
+ # %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {})
184
+ # %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {})
185
+ # %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {})
186
+ # %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {})
187
+ # %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {})
188
+ # return %sum_2,%convert_element_type_3
189
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', '''
190
+ import triton
191
+ import triton.language as tl
192
+
193
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
194
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
195
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
196
+ triton_helpers.set_driver_to_gpu()
197
+
198
+ @triton_heuristics.reduction(
199
+ size_hints={'x': 4096, 'r0_': 4096},
200
+ reduction_hint=ReductionHint.INNER,
201
+ filename=__file__,
202
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
203
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
204
+ )
205
+ @triton.jit
206
+ def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
207
+ rnumel = r0_numel
208
+ RBLOCK: tl.constexpr = R0_BLOCK
209
+ xoffset = tl.program_id(0) * XBLOCK
210
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
211
+ xmask = xindex < xnumel
212
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
213
+ rbase = r0_base
214
+ x0 = xindex
215
+ _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
216
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
217
+ r0_index = r0_offset + r0_base
218
+ r0_mask = r0_index < r0_numel
219
+ roffset = r0_offset
220
+ rindex = r0_index
221
+ r0_1 = r0_index
222
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
223
+ tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
224
+ tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
225
+ tmp2 = tmp0 * tmp1
226
+ tmp3 = tmp2.to(tl.float32)
227
+ tmp5 = tmp4.to(tl.float32)
228
+ tmp6 = tmp3 * tmp5
229
+ tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
230
+ tmp9 = _tmp8 + tmp7
231
+ _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8)
232
+ tmp8 = tl.sum(_tmp8, 1)[:, None]
233
+ tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
234
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
235
+ r0_index = r0_offset + r0_base
236
+ r0_mask = r0_index < r0_numel
237
+ roffset = r0_offset
238
+ rindex = r0_index
239
+ r0_1 = r0_index
240
+ tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
241
+ tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
242
+ tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
243
+ tmp12 = tmp10 * tmp11
244
+ tmp13 = tmp12.to(tl.float32)
245
+ tmp15 = tmp13 * tmp14
246
+ tmp16 = -0.5
247
+ tmp17 = tmp8 * tmp16
248
+ tmp18 = tmp14 * tmp14
249
+ tmp19 = tmp18 * tmp14
250
+ tmp20 = tmp17 * tmp19
251
+ tmp21 = ks0
252
+ tmp22 = tmp21.to(tl.float32)
253
+ tmp23 = (tmp20 / tmp22)
254
+ tmp25 = tmp24.to(tl.float32)
255
+ tmp26 = 2.0
256
+ tmp27 = tmp25 * tmp26
257
+ tmp28 = tmp23 * tmp27
258
+ tmp29 = tmp15 + tmp28
259
+ tmp30 = tmp29.to(tl.float32)
260
+ tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask)
261
+ ''', device_str='cuda')
262
+
263
+
264
+ async_compile.wait(globals())
265
+ del async_compile
266
+
267
+ class Runner:
268
+ def __init__(self, partitions):
269
+ self.partitions = partitions
270
+
271
+ def recursively_apply_fns(self, fns):
272
+ new_callables = []
273
+ for fn, c in zip(fns, self.partitions):
274
+ new_callables.append(fn(c))
275
+ self.partitions = new_callables
276
+
277
+ def call(self, args):
278
+ primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args
279
+ args.clear()
280
+ s47 = primals_1
281
+ s87 = primals_2
282
+ s33 = primals_3
283
+ s82 = primals_6
284
+ assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1))
285
+ assert_size_stride(primals_7, (s33, ), (1, ))
286
+ assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1))
287
+ assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1))
288
+ with torch.cuda._DeviceGuard(1):
289
+ torch.cuda.set_device(1)
290
+ buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32)
291
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
292
+ triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33
293
+ triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32
294
+ stream1 = get_raw_stream(1)
295
+ triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream1)
296
+ buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16)
297
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
298
+ stream1 = get_raw_stream(1)
299
+ triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream1)
300
+ del buf0
301
+ buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16)
302
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add]
303
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel = s47*s87
304
+ stream1 = get_raw_stream(1)
305
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream1)
306
+ del primals_4
307
+ del primals_7
308
+ del rsqrt
309
+ del tangents_1
310
+ return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), )
311
+
312
+ runner = Runner(partitions=[])
313
+ call = runner.call
314
+ recursively_apply_fns = runner.recursively_apply_fns
315
+
316
+
317
+ def benchmark_compiled_module(times=10, repeat=10):
318
+ from torch._dynamo.testing import rand_strided
319
+ from torch._inductor.utils import print_performance
320
+ primals_1 = 2
321
+ primals_2 = 2048
322
+ primals_3 = 4096
323
+ primals_6 = 840433664
324
+ primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
325
+ primals_7 = rand_strided((4096, ), (1, ), device='cuda:1', dtype=torch.bfloat16)
326
+ rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.float32)
327
+ tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
328
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1])
329
+ return print_performance(fn, times=times, repeat=repeat)
330
+
331
+
332
+ if __name__ == "__main__":
333
+ from torch._inductor.wrapper_benchmark import compiled_module_main
334
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/5n/c5n5tizmmgs4cmiupzpopubn6t7eviwt42e3csvt472h63vjwmbu.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 2
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = ks0
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/6a/c6akpy3glququp6suktd5kfns5jol46fmjfl5brlavs7c4zodqhi.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['15_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/x7/cx7fsejzde6zv22nl7w3xpjhybajijgeetsfqi733ibymkptkdrq.py
38
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
39
+ # Source node to ATen node mapping:
40
+ # argmax => argmax
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1]
43
+ # %argmax : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
44
+ # return %argmax
45
+ triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', '''
46
+ import triton
47
+ import triton.language as tl
48
+
49
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
50
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
51
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
52
+ triton_helpers.set_driver_to_gpu()
53
+
54
+ @triton_heuristics.reduction(
55
+ size_hints={'x': 4096, 'r0_': 32768},
56
+ reduction_hint=ReductionHint.INNER,
57
+ filename=__file__,
58
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
59
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
60
+ )
61
+ @triton.jit
62
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
63
+ r0_numel = 32000
64
+ rnumel = r0_numel
65
+ RBLOCK: tl.constexpr = R0_BLOCK
66
+ xoffset = tl.program_id(0) * XBLOCK
67
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
68
+ xmask = xindex < xnumel
69
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
70
+ rbase = r0_base
71
+ x0 = xindex
72
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
73
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
74
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
75
+ r0_index = r0_offset + r0_base
76
+ r0_mask = r0_index < r0_numel
77
+ roffset = r0_offset
78
+ rindex = r0_index
79
+ r0_1 = r0_index
80
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
81
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
82
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
83
+ _tmp2, _tmp2_index, tmp1, rindex
84
+ )
85
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
86
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
87
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
88
+ tmp2 = tmp2_idx[:, None]
89
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
90
+ ''', device_str='cuda')
91
+
92
+
93
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py
94
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
95
+ # Source node to ATen node mapping:
96
+ # argmax_1 => argmax_1
97
+ # Graph fragment:
98
+ # %arg4_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:3" = PlaceHolder[target=arg4_1]
99
+ # %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg4_1, -1), kwargs = {})
100
+ # return %argmax_1
101
+ triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', '''
102
+ import triton
103
+ import triton.language as tl
104
+
105
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
106
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
107
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
108
+ triton_helpers.set_driver_to_gpu()
109
+
110
+ @triton_heuristics.reduction(
111
+ size_hints={'x': 4096, 'r0_': 32768},
112
+ reduction_hint=ReductionHint.INNER,
113
+ filename=__file__,
114
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
115
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
116
+ )
117
+ @triton.jit
118
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
119
+ r0_numel = 32000
120
+ rnumel = r0_numel
121
+ RBLOCK: tl.constexpr = R0_BLOCK
122
+ xoffset = tl.program_id(0) * XBLOCK
123
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
124
+ xmask = xindex < xnumel
125
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
126
+ rbase = r0_base
127
+ x0 = (xindex % ks0)
128
+ x1 = xindex // ks0
129
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
130
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
131
+ x3 = xindex
132
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
133
+ r0_index = r0_offset + r0_base
134
+ r0_mask = r0_index < r0_numel
135
+ roffset = r0_offset
136
+ rindex = r0_index
137
+ r0_2 = r0_index
138
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
139
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
140
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
141
+ _tmp2, _tmp2_index, tmp1, rindex
142
+ )
143
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
144
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
145
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
146
+ tmp2 = tmp2_idx[:, None]
147
+ tl.store(out_ptr0 + (x3), tmp2, xmask)
148
+ ''', device_str='cuda')
149
+
150
+
151
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2o/c2otr5mtbf3tmh4ztmfjn6qv6r3raha22m4sr5h4kaplsk53xtg4.py
152
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
153
+ # Source node to ATen node mapping:
154
+ # eq => eq_2
155
+ # mul => mul_3
156
+ # squeeze => squeeze
157
+ # sum_1 => sum_1
158
+ # Graph fragment:
159
+ # %argmax : Tensor "i64[2, s3][s3, 1]cuda:3" = PlaceHolder[target=argmax]
160
+ # %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:3" = PlaceHolder[target=argmax_1]
161
+ # %arg5_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:3" = PlaceHolder[target=arg5_1]
162
+ # %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
163
+ # %squeeze : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg5_1, -1), kwargs = {})
164
+ # %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {})
165
+ # %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {})
166
+ # return %sum_1
167
+ triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', '''
168
+ import triton
169
+ import triton.language as tl
170
+
171
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
172
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
173
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
174
+ triton_helpers.set_driver_to_gpu()
175
+
176
+ @triton_heuristics.reduction(
177
+ size_hints={'x': 1, 'r0_': 4096},
178
+ reduction_hint=ReductionHint.INNER,
179
+ filename=__file__,
180
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
181
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
182
+ )
183
+ @triton.jit
184
+ def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
185
+ xnumel = 1
186
+ rnumel = r0_numel
187
+ RBLOCK: tl.constexpr = R0_BLOCK
188
+ xoffset = tl.program_id(0) * XBLOCK
189
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
190
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
191
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
192
+ rbase = r0_base
193
+ _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
194
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
195
+ r0_index = r0_offset + r0_base
196
+ r0_mask = r0_index < r0_numel
197
+ roffset = r0_offset
198
+ rindex = r0_index
199
+ r0_0 = r0_index
200
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
201
+ tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
202
+ tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
203
+ tmp2 = tmp0 == tmp1
204
+ tmp3 = tmp2.to(tl.int64)
205
+ tmp5 = tmp3 * tmp4
206
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
207
+ tmp8 = _tmp7 + tmp6
208
+ _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
209
+ tmp7 = tl.sum(_tmp7, 1)[:, None]
210
+ tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None)
211
+ ''', device_str='cuda')
212
+
213
+
214
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cj/ccjx73kwqy3z57a3fjxor5ma5tgytixf7htmrtqxzyfleohcklv4.py
215
+ # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div]
216
+ # Source node to ATen node mapping:
217
+ # clamp_min => clamp_min
218
+ # sum_2 => sum_2
219
+ # truediv => div
220
+ # Graph fragment:
221
+ # %arg7_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:3" = PlaceHolder[target=arg7_1]
222
+ # %sum_1 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_1]
223
+ # %sum_2 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_2]
224
+ # %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg7_1,), kwargs = {})
225
+ # %clamp_min : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {})
226
+ # %div : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {})
227
+ # return %sum_2,%div
228
+ triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', '''
229
+ import triton
230
+ import triton.language as tl
231
+
232
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
233
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
234
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
235
+ triton_helpers.set_driver_to_gpu()
236
+
237
+ @triton_heuristics.reduction(
238
+ size_hints={'x': 1, 'r0_': 4096},
239
+ reduction_hint=ReductionHint.INNER,
240
+ filename=__file__,
241
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
242
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
243
+ )
244
+ @triton.jit
245
+ def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
246
+ xnumel = 1
247
+ rnumel = r0_numel
248
+ RBLOCK: tl.constexpr = R0_BLOCK
249
+ xoffset = tl.program_id(0) * XBLOCK
250
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
251
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
252
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
253
+ rbase = r0_base
254
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
255
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
256
+ r0_index = r0_offset + r0_base
257
+ r0_mask = r0_index < r0_numel
258
+ roffset = r0_offset
259
+ rindex = r0_index
260
+ r0_0 = r0_index
261
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
262
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
263
+ tmp3 = _tmp2 + tmp1
264
+ _tmp2 = tl.where(r0_mask, tmp3, _tmp2)
265
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
266
+ tmp4 = tl.load(in_ptr1 + (0))
267
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1])
268
+ tmp6 = tmp5.to(tl.float32)
269
+ tmp7 = tmp2.to(tl.float32)
270
+ tmp8 = 1e-06
271
+ tmp9 = triton_helpers.maximum(tmp7, tmp8)
272
+ tmp10 = (tmp6 / tmp9)
273
+ tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None)
274
+ ''', device_str='cuda')
275
+
276
+
277
+ async_compile.wait(globals())
278
+ del async_compile
279
+
280
+ class Runner:
281
+ def __init__(self, partitions):
282
+ self.partitions = partitions
283
+
284
+ def recursively_apply_fns(self, fns):
285
+ new_callables = []
286
+ for fn, c in zip(fns, self.partitions):
287
+ new_callables.append(fn(c))
288
+ self.partitions = new_callables
289
+
290
+ def call(self, args):
291
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 = args
292
+ args.clear()
293
+ s3 = arg0_1
294
+ s71 = arg2_1
295
+ s0 = arg3_1
296
+ s14 = arg6_1
297
+ assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1))
298
+ assert_size_stride(arg4_1, (2, s3, 32000), (s71, 32000, 1))
299
+ assert_size_stride(arg5_1, (2, s3, 1), (s3, 1, 1))
300
+ assert_size_stride(arg7_1, (2, s14, 1), (s14, 1, 1))
301
+ with torch.cuda._DeviceGuard(3):
302
+ torch.cuda.set_device(3)
303
+ buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64)
304
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
305
+ triton_red_fused_argmax_0_xnumel = 2*s3
306
+ stream3 = get_raw_stream(3)
307
+ triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream3)
308
+ del arg1_1
309
+ buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64)
310
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
311
+ triton_red_fused_argmax_1_xnumel = 2*s3
312
+ stream3 = get_raw_stream(3)
313
+ triton_red_fused_argmax_1.run(arg4_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream3)
314
+ del arg4_1
315
+ buf2 = empty_strided_cuda((), (), torch.int64)
316
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
317
+ triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3
318
+ stream3 = get_raw_stream(3)
319
+ triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg5_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream3)
320
+ del arg5_1
321
+ del buf0
322
+ del buf1
323
+ buf4 = empty_strided_cuda((), (), torch.float32)
324
+ # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div]
325
+ triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14
326
+ stream3 = get_raw_stream(3)
327
+ triton_red_fused_clamp_min_div_sum_3.run(arg7_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream3)
328
+ del arg7_1
329
+ del buf2
330
+ return (buf4, )
331
+
332
+ runner = Runner(partitions=[])
333
+ call = runner.call
334
+ recursively_apply_fns = runner.recursively_apply_fns
335
+
336
+
337
+ def benchmark_compiled_module(times=10, repeat=10):
338
+ from torch._dynamo.testing import rand_strided
339
+ from torch._inductor.utils import print_performance
340
+ arg0_1 = 2014
341
+ arg1_1 = rand_strided((2, 2014, 32000), (64448000, 32000, 1), device='cuda:3', dtype=torch.bfloat16)
342
+ arg2_1 = 64672000
343
+ arg3_1 = 32000
344
+ arg4_1 = rand_strided((2, 2014, 32000), (64672000, 32000, 1), device='cuda:3', dtype=torch.float32)
345
+ arg5_1 = rand_strided((2, 2014, 1), (2014, 1, 1), device='cuda:3', dtype=torch.int64)
346
+ arg6_1 = 2014
347
+ arg7_1 = rand_strided((2, 2014, 1), (2014, 1, 1), device='cuda:3', dtype=torch.int64)
348
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1])
349
+ return print_performance(fn, times=times, repeat=repeat)
350
+
351
+
352
+ if __name__ == "__main__":
353
+ from torch._inductor.wrapper_benchmark import compiled_module_main
354
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4096},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 2176
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.full([1], 0, tl.int32)
25
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/73/c73fwyrbq77x2xei6vgy66r34rmoo6pilk6kf2iqehy6oksjmbwj.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['9_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hj/chjn2h2lagxtgealz3aitqmfnksszmnt7q4hnsw5vu6risac6dmq.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1]
44
+ # %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_3]
45
+ # %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_5]
46
+ # %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1]
48
+ # %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9]
49
+ # %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_7]
50
+ # %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11]
51
+ # %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_13]
52
+ # %primals_10 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_10]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = False
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
142
+
143
+ ZQ = 2
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 2
147
+ KV_LEN = ks0
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 2
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 16*ks1
185
+ stride_kv_idx_m = ks1
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = False
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB reversed order to since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = False
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
576
+ # it's on par or slightly faster than only applying to the last block in fwd.
577
+ # However, we choose different strategy for bwd, where we only apply mod & mask
578
+ # to the last block because it's faster a lot.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args
625
+ args.clear()
626
+ s0 = primals_2
627
+ s43 = primals_4
628
+ s72 = primals_6
629
+ s71 = primals_8
630
+ s4 = primals_12
631
+ s56 = primals_14
632
+ s84 = primals_16
633
+ s99 = primals_18
634
+ s6 = primals_20
635
+ assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
636
+ assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
637
+ assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
638
+ assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1))
639
+ assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1))
640
+ assert_size_stride(primals_10, (2, ), (1, ))
641
+ assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1))
642
+ assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1))
643
+ assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1))
644
+ assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1))
645
+ assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1))
646
+ assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1))
647
+ with torch.cuda._DeviceGuard(3):
648
+ torch.cuda.set_device(3)
649
+ buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
650
+ buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
651
+ buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
652
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
653
+ stream3 = get_raw_stream(3)
654
+ triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream3)
655
+ del buf1
656
+ return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, )
657
+
658
+ runner = Runner(partitions=[])
659
+ call = runner.call
660
+ recursively_apply_fns = runner.recursively_apply_fns
661
+
662
+
663
+ def benchmark_compiled_module(times=10, repeat=10):
664
+ from torch._dynamo.testing import rand_strided
665
+ from torch._inductor.utils import print_performance
666
+ primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
667
+ primals_2 = 4096
668
+ primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16)
669
+ primals_4 = 4096
670
+ primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16)
671
+ primals_6 = 32
672
+ primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32)
673
+ primals_8 = 4096
674
+ primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
675
+ primals_10 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64)
676
+ primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
677
+ primals_12 = 32
678
+ primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32)
679
+ primals_14 = 32
680
+ primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32)
681
+ primals_16 = 32
682
+ primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32)
683
+ primals_18 = 32
684
+ primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32)
685
+ primals_20 = 32
686
+ primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32)
687
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21])
688
+ return print_performance(fn, times=times, repeat=repeat)
689
+
690
+
691
+ if __name__ == "__main__":
692
+ from torch._inductor.wrapper_benchmark import compiled_module_main
693
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/bj/cbj72m23cmcn2yjoxrp4vabc2f76gw727jcpbi4y5oidokqenki5.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 2048},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = xindex
23
+ tmp0 = tl.full([1], 0, tl.int32)
24
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/ce/ccez434a7hzyympuosxgkqmu5zncaqowipase2enwlehm3k7igny.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 512, 'r0_': 16384},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 512
20
+ r0_numel = 16384
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x1 = ((xindex // 16) % 16)
29
+ x0 = (xindex % 16)
30
+ x2 = xindex // 256
31
+ tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
32
+ _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
33
+ x6 = xindex
34
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
35
+ r0_index = r0_offset + r0_base
36
+ r0_mask = r0_index < r0_numel
37
+ roffset = r0_offset
38
+ rindex = r0_index
39
+ r0_4 = r0_index // 128
40
+ r0_3 = (r0_index % 128)
41
+ tmp0 = r0_4 + 128*x1
42
+ tmp1 = r0_3 + 128*x0
43
+ tmp2 = tmp0 >= tmp1
44
+ tmp4 = tmp1 < tmp3
45
+ tmp5 = tmp0 < tmp3
46
+ tmp6 = tmp4 & tmp5
47
+ tmp7 = tmp2 & tmp6
48
+ tmp8 = tl.full([1, 1], False, tl.int1)
49
+ tmp9 = tmp8 | tmp7
50
+ tmp10 = tl.full([1, 1], 2048, tl.int64)
51
+ tmp11 = tmp1 >= tmp10
52
+ tmp12 = tmp11 & tmp4
53
+ tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
54
+ tmp14 = (tmp13 % tmp10)
55
+ tmp15 = tl.full([1, 1], 0, tl.int32)
56
+ tmp16 = tmp14 != tmp15
57
+ tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
58
+ tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
59
+ tmp19 = tmp17 != tmp18
60
+ tmp20 = tmp16 & tmp19
61
+ tmp21 = tmp14 + tmp10
62
+ tmp22 = tl.where(tmp20, tmp21, tmp14)
63
+ tmp23 = tl.full([1, 1], 0, tl.int64)
64
+ tmp24 = tmp22 == tmp23
65
+ tmp25 = tmp12 & tmp24
66
+ tmp26 = tmp9 | tmp25
67
+ tmp27 = tmp26.to(tl.int64)
68
+ tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
69
+ tmp30 = _tmp29 + tmp28
70
+ _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
71
+ tmp29 = tl.sum(_tmp29, 1)[:, None]
72
+ tl.store(out_ptr0 + (x6), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/cm/c07273381756209821a51449b0970c31551d44d464333fbb852c9fc655362c46.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"}
SpecForge-ext/cache/compiled_kernels/cr/ccr5s7nffy4cqd7a3lcq3cnv2prruzwzc7chchf776jguuqqh5bc.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 16777216},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x3 = xindex
24
+ x1 = ((xindex // ks0) % ks1)
25
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
26
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
27
+ tmp0 = x0
28
+ tmp1 = ks0 // 2
29
+ tmp2 = tmp0 >= tmp1
30
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
31
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
32
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
33
+ tmp6 = tmp4 + tmp5
34
+ tmp7 = tmp4 < 0
35
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
36
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
37
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
38
+ tmp11 = tmp3 * tmp10
39
+ tmp12 = -tmp11
40
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
41
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
42
+ tmp15 = 0.0
43
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
44
+ tmp17 = tmp0 < tmp1
45
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
46
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
47
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
48
+ tmp21 = tmp19 + tmp20
49
+ tmp22 = tmp19 < 0
50
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
51
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
52
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
53
+ tmp26 = tmp18 * tmp25
54
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
55
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
56
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
57
+ tmp30 = tmp16 + tmp29
58
+ tmp33 = ks3
59
+ tmp34 = tmp32 + tmp33
60
+ tmp35 = tmp32 < 0
61
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
62
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
63
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
64
+ tmp39 = tmp31 * tmp38
65
+ tmp40 = tmp30 + tmp39
66
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
SpecForge-ext/cache/compiled_kernels/d4/cd4qgy6v3vbg74qytdbsmdpamjzb6kuwcsiu7yfpml4f7zxhf4j3.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['0_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zw/czwh6tgkq6scdstgzueb3goqqnllndikoasj2i2iehu2qyvoccwt.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg0_1 : Tensor "bf16[8, 2048, 151936][311164928, 151936, 1]cuda:6" = PlaceHolder[target=arg0_1]
47
+ # %argmax : Tensor "i64[8, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax]
48
+ # %arg1_1 : Tensor "b8[151936][1]cuda:6" = PlaceHolder[target=arg1_1]
49
+ # %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg2_1]
50
+ # %argmax : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {})
55
+ # return %argmax,%mul
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 16384, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ xnumel = 16384
75
+ r0_numel = 151936
76
+ rnumel = r0_numel
77
+ RBLOCK: tl.constexpr = R0_BLOCK
78
+ xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
79
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
80
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
81
+ r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
82
+ rbase = r0_base
83
+ x0 = xindex
84
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
85
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
86
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
87
+ r0_index = r0_offset + r0_base
88
+ r0_mask = r0_index < r0_numel
89
+ roffset = r0_offset
90
+ rindex = r0_index
91
+ r0_1 = r0_index
92
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
93
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
94
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
95
+ _tmp2, _tmp2_index, tmp1, rindex
96
+ )
97
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
98
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
99
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
100
+ tmp2 = tmp2_idx[:, None]
101
+ tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
102
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
103
+ tmp4 = tmp2 + tmp3
104
+ tmp5 = tmp2 < 0
105
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
106
+ tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936")
107
+ tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1)
108
+ tmp9 = tmp8.to(tl.int32)
109
+ tmp10 = tmp9.to(tl.int64)
110
+ tmp12 = tmp10 * tmp11
111
+ tl.debug_barrier()
112
+ tl.store(in_out_ptr0 + (x0), tmp12, None)
113
+ ''', device_str='cuda')
114
+
115
+
116
+ async_compile.wait(globals())
117
+ del async_compile
118
+
119
+ class Runner:
120
+ def __init__(self, partitions):
121
+ self.partitions = partitions
122
+
123
+ def recursively_apply_fns(self, fns):
124
+ new_callables = []
125
+ for fn, c in zip(fns, self.partitions):
126
+ new_callables.append(fn(c))
127
+ self.partitions = new_callables
128
+
129
+ def call(self, args):
130
+ arg0_1, arg1_1, arg2_1 = args
131
+ args.clear()
132
+ assert_size_stride(arg0_1, (8, 2048, 151936), (311164928, 151936, 1))
133
+ assert_size_stride(arg1_1, (151936, ), (1, ))
134
+ assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1))
135
+ with torch.cuda._DeviceGuard(6):
136
+ torch.cuda.set_device(6)
137
+ buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64)
138
+ buf1 = reinterpret_tensor(buf0, (8, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse
139
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
140
+ stream6 = get_raw_stream(6)
141
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 16384, 151936, stream=stream6)
142
+ del arg0_1
143
+ del arg1_1
144
+ del arg2_1
145
+ return (buf1, )
146
+
147
+ runner = Runner(partitions=[])
148
+ call = runner.call
149
+ recursively_apply_fns = runner.recursively_apply_fns
150
+
151
+
152
+ def benchmark_compiled_module(times=10, repeat=10):
153
+ from torch._dynamo.testing import rand_strided
154
+ from torch._inductor.utils import print_performance
155
+ arg0_1 = rand_strided((8, 2048, 151936), (311164928, 151936, 1), device='cuda:6', dtype=torch.bfloat16)
156
+ arg1_1 = rand_strided((151936, ), (1, ), device='cuda:6', dtype=torch.bool)
157
+ arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64)
158
+ fn = lambda: call([arg0_1, arg1_1, arg2_1])
159
+ return print_performance(fn, times=times, repeat=repeat)
160
+
161
+
162
+ if __name__ == "__main__":
163
+ from torch._inductor.wrapper_benchmark import compiled_module_main
164
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/e4/9676ab333d44c7c3eec122b806e4fc2028468bcd55a94115e7272e322515d58c.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"}
SpecForge-ext/cache/compiled_kernels/e6/ce65awkeaxcxjqfa27pcogsy3sjyxwzxjt3w2rte76m7izgybp2s.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 512, 'r0_': 16384},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 512
20
+ r0_numel = 16384
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x1 = ((xindex // 16) % 16)
29
+ x0 = (xindex % 16)
30
+ x2 = xindex // 256
31
+ tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
32
+ _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
33
+ x6 = xindex
34
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
35
+ r0_index = r0_offset + r0_base
36
+ r0_mask = r0_index < r0_numel
37
+ roffset = r0_offset
38
+ rindex = r0_index
39
+ r0_4 = r0_index // 128
40
+ r0_3 = (r0_index % 128)
41
+ tmp0 = r0_4 + 128*x1
42
+ tmp1 = r0_3 + 128*x0
43
+ tmp2 = tmp0 >= tmp1
44
+ tmp4 = tmp1 < tmp3
45
+ tmp5 = tmp0 < tmp3
46
+ tmp6 = tmp4 & tmp5
47
+ tmp7 = tmp2 & tmp6
48
+ tmp8 = tl.full([1, 1], False, tl.int1)
49
+ tmp9 = tmp8 | tmp7
50
+ tmp10 = tl.full([1, 1], 2048, tl.int64)
51
+ tmp11 = tmp1 >= tmp10
52
+ tmp12 = tmp11 & tmp4
53
+ tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
54
+ tmp14 = (tmp13 % tmp10)
55
+ tmp15 = tl.full([1, 1], 0, tl.int32)
56
+ tmp16 = tmp14 != tmp15
57
+ tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
58
+ tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
59
+ tmp19 = tmp17 != tmp18
60
+ tmp20 = tmp16 & tmp19
61
+ tmp21 = tmp14 + tmp10
62
+ tmp22 = tl.where(tmp20, tmp21, tmp14)
63
+ tmp23 = tl.full([1, 1], 0, tl.int64)
64
+ tmp24 = tmp22 == tmp23
65
+ tmp25 = tmp12 & tmp24
66
+ tmp26 = tmp9 | tmp25
67
+ tmp27 = tmp26.to(tl.int64)
68
+ tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
69
+ tmp30 = _tmp29 + tmp28
70
+ _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
71
+ tmp29 = tl.sum(_tmp29, 1)[:, None]
72
+ tl.store(out_ptr0 + (x6), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/ea/ceahttlkg35qey3ao6gw65rzzv3bop5xwrthhogt6nyvyw3rece5.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = ks0
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/eg/32c642d9b91cf1b4fe91f745e14ee86eccc7f5783759ca318976f5af47c474c9.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"}
SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x3 = xindex
24
+ x1 = ((xindex // ks0) % ks1)
25
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
26
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
27
+ tmp0 = x0
28
+ tmp1 = ks0 // 2
29
+ tmp2 = tmp0 >= tmp1
30
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
31
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
32
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
33
+ tmp6 = tmp4 + tmp5
34
+ tmp7 = tmp4 < 0
35
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
36
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
37
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
38
+ tmp11 = tmp3 * tmp10
39
+ tmp12 = -tmp11
40
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
41
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
42
+ tmp15 = 0.0
43
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
44
+ tmp17 = tmp0 < tmp1
45
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
46
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
47
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
48
+ tmp21 = tmp19 + tmp20
49
+ tmp22 = tmp19 < 0
50
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
51
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
52
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
53
+ tmp26 = tmp18 * tmp25
54
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
55
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
56
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
57
+ tmp30 = tmp16 + tmp29
58
+ tmp33 = ks3
59
+ tmp34 = tmp32 + tmp33
60
+ tmp35 = tmp32 < 0
61
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
62
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
63
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
64
+ tmp39 = tmp31 * tmp38
65
+ tmp40 = tmp30 + tmp39
66
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
SpecForge-ext/cache/compiled_kernels/eg/cegtc7d6fywtdtf2rerfuwwwn7fajohhh2ltqlvjjvdguxabbva4.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['14_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/p3/cp3lt4qtmnlmp6kb7cx5zc6bshlrxlbfed2c4ciyoiapxknraax3.py
38
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
39
+ # Source node to ATen node mapping:
40
+ # argmax => argmax
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[8, s3, 32000][32000*s3, 32000, 1]cuda:0" = PlaceHolder[target=arg1_1]
43
+ # %argmax : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
44
+ # return %argmax
45
+ triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', '''
46
+ import triton
47
+ import triton.language as tl
48
+
49
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
50
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
51
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
52
+ triton_helpers.set_driver_to_gpu()
53
+
54
+ @triton_heuristics.reduction(
55
+ size_hints={'x': 16384, 'r0_': 32768},
56
+ reduction_hint=ReductionHint.INNER,
57
+ filename=__file__,
58
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
59
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
60
+ )
61
+ @triton.jit
62
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
63
+ r0_numel = 32000
64
+ rnumel = r0_numel
65
+ RBLOCK: tl.constexpr = R0_BLOCK
66
+ xoffset = tl.program_id(0) * XBLOCK
67
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
68
+ xmask = xindex < xnumel
69
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
70
+ rbase = r0_base
71
+ x0 = xindex
72
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
73
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
74
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
75
+ r0_index = r0_offset + r0_base
76
+ r0_mask = r0_index < r0_numel
77
+ roffset = r0_offset
78
+ rindex = r0_index
79
+ r0_1 = r0_index
80
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
81
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
82
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
83
+ _tmp2, _tmp2_index, tmp1, rindex
84
+ )
85
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
86
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
87
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
88
+ tmp2 = tmp2_idx[:, None]
89
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
90
+ ''', device_str='cuda')
91
+
92
+
93
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/eu/ceui6qrb2t3lmzs3ljrqtcomt4b2q6svzo24j6mmryaiovr6kp7y.py
94
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
95
+ # Source node to ATen node mapping:
96
+ # argmax_1 => argmax_1
97
+ # Graph fragment:
98
+ # %arg3_1 : Tensor "f32[8, s3, 32000][s71, 32000, 1]cuda:0" = PlaceHolder[target=arg3_1]
99
+ # %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {})
100
+ # return %argmax_1
101
+ triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', '''
102
+ import triton
103
+ import triton.language as tl
104
+
105
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
106
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
107
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
108
+ triton_helpers.set_driver_to_gpu()
109
+
110
+ @triton_heuristics.reduction(
111
+ size_hints={'x': 16384, 'r0_': 32768},
112
+ reduction_hint=ReductionHint.DEFAULT,
113
+ filename=__file__,
114
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
115
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
116
+ )
117
+ @triton.jit
118
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
119
+ r0_numel = 32000
120
+ rnumel = r0_numel
121
+ RBLOCK: tl.constexpr = R0_BLOCK
122
+ xoffset = tl.program_id(0) * XBLOCK
123
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
124
+ xmask = xindex < xnumel
125
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
126
+ rbase = r0_base
127
+ x0 = (xindex % ks0)
128
+ x1 = xindex // ks0
129
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
130
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
131
+ x3 = xindex
132
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
133
+ r0_index = r0_offset + r0_base
134
+ r0_mask = r0_index < r0_numel
135
+ roffset = r0_offset
136
+ rindex = r0_index
137
+ r0_2 = r0_index
138
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
139
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
140
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
141
+ _tmp2, _tmp2_index, tmp1, rindex
142
+ )
143
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
144
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
145
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
146
+ tmp2 = tmp2_idx[:, None]
147
+ tl.store(out_ptr0 + (x3), tmp2, xmask)
148
+ ''', device_str='cuda')
149
+
150
+
151
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py
152
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
153
+ # Source node to ATen node mapping:
154
+ # eq => eq_2
155
+ # mul => mul_7
156
+ # squeeze => squeeze
157
+ # sum_1 => sum_1
158
+ # Graph fragment:
159
+ # %argmax : Tensor "i64[8, s3][s3, 1]cuda:0" = PlaceHolder[target=argmax]
160
+ # %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:0" = PlaceHolder[target=argmax_1]
161
+ # %arg4_1 : Tensor "i64[8, s3, 1][s3, 1, 1]cuda:0" = PlaceHolder[target=arg4_1]
162
+ # %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
163
+ # %squeeze : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {})
164
+ # %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {})
165
+ # %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {})
166
+ # return %buf3
167
+ triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', '''
168
+ import triton
169
+ import triton.language as tl
170
+
171
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
172
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
173
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
174
+ triton_helpers.set_driver_to_gpu()
175
+
176
+ @triton_heuristics.reduction(
177
+ size_hints={'x': 2, 'r0_': 8192},
178
+ reduction_hint=ReductionHint.INNER,
179
+ filename=__file__,
180
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
181
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
182
+ )
183
+ @triton.jit
184
+ def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
185
+ xnumel = 2
186
+ rnumel = r0_numel
187
+ RBLOCK: tl.constexpr = R0_BLOCK
188
+ xoffset = tl.program_id(0) * XBLOCK
189
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
190
+ xmask = xindex < xnumel
191
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
192
+ rbase = r0_base
193
+ x0 = xindex
194
+ _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
195
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
196
+ r0_index = r0_offset + r0_base
197
+ r0_mask = r0_index < r0_numel
198
+ roffset = r0_offset
199
+ rindex = r0_index
200
+ r0_1 = r0_index
201
+ tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
202
+ tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
203
+ tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
204
+ tmp2 = tmp0 == tmp1
205
+ tmp3 = tmp2.to(tl.int64)
206
+ tmp5 = tmp3 * tmp4
207
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
208
+ tmp8 = _tmp7 + tmp6
209
+ _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7)
210
+ tmp7 = tl.sum(_tmp7, 1)[:, None]
211
+ tl.store(out_ptr0 + (x0), tmp7, xmask)
212
+ ''', device_str='cuda')
213
+
214
+
215
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xe/cxe74qazmcwxkyh3xlgupaetbeksmhlptcogpgxu7tfvr4arcob6.py
216
+ # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum]
217
+ # Source node to ATen node mapping:
218
+ # sum_2 => sum_2
219
+ # Graph fragment:
220
+ # %arg6_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg6_1]
221
+ # %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {})
222
+ # return %buf5
223
+ triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', '''
224
+ import triton
225
+ import triton.language as tl
226
+
227
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
228
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
229
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
230
+ triton_helpers.set_driver_to_gpu()
231
+
232
+ @triton_heuristics.reduction(
233
+ size_hints={'x': 2, 'r0_': 8192},
234
+ reduction_hint=ReductionHint.INNER,
235
+ filename=__file__,
236
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
237
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
238
+ )
239
+ @triton.jit
240
+ def triton_red_fused_sum_3(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
241
+ xnumel = 2
242
+ rnumel = r0_numel
243
+ RBLOCK: tl.constexpr = R0_BLOCK
244
+ xoffset = tl.program_id(0) * XBLOCK
245
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
246
+ xmask = xindex < xnumel
247
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
248
+ rbase = r0_base
249
+ x0 = xindex
250
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
251
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
252
+ r0_index = r0_offset + r0_base
253
+ r0_mask = r0_index < r0_numel
254
+ roffset = r0_offset
255
+ rindex = r0_index
256
+ r0_1 = r0_index
257
+ tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
258
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
259
+ tmp3 = _tmp2 + tmp1
260
+ _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
261
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
262
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
263
+ ''', device_str='cuda')
264
+
265
+
266
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/az/cazc4elakae7tgyuygha6gaxmfo4ouj4mtb6kxylbj7524jvkqaz.py
267
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
268
+ # Source node to ATen node mapping:
269
+ # clamp_min => clamp_min
270
+ # eq => eq_2
271
+ # mul => mul_7
272
+ # squeeze => squeeze
273
+ # sum_1 => sum_1
274
+ # sum_2 => sum_2
275
+ # truediv => div
276
+ # Graph fragment:
277
+ # %buf3 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=buf3]
278
+ # %buf5 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=buf5]
279
+ # %sum_1 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_1]
280
+ # %sum_2 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_2]
281
+ # %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
282
+ # %squeeze : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {})
283
+ # %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {})
284
+ # %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {})
285
+ # %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {})
286
+ # %clamp_min : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {})
287
+ # %div : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {})
288
+ # return %sum_1,%sum_2,%div
289
+ triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', '''
290
+ import triton
291
+ import triton.language as tl
292
+
293
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
294
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
295
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
296
+ triton_helpers.set_driver_to_gpu()
297
+
298
+ @triton_heuristics.persistent_reduction(
299
+ size_hints={'x': 1, 'r0_': 2},
300
+ reduction_hint=ReductionHint.INNER,
301
+ filename=__file__,
302
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
303
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}}
304
+ )
305
+ @triton.jit
306
+ def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr):
307
+ xnumel = 1
308
+ r0_numel = 2
309
+ R0_BLOCK: tl.constexpr = 2
310
+ rnumel = r0_numel
311
+ RBLOCK: tl.constexpr = R0_BLOCK
312
+ xoffset = tl.program_id(0) * XBLOCK
313
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
314
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
315
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
316
+ r0_offset = 0
317
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
318
+ roffset = r0_offset
319
+ rindex = r0_index
320
+ r0_0 = r0_index
321
+ tmp0 = tl.load(in_ptr0 + (r0_0), None)
322
+ tmp4 = tl.load(in_ptr1 + (r0_0), None)
323
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
324
+ tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64)
325
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
326
+ tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64)
327
+ tmp8 = tmp3.to(tl.float32)
328
+ tmp9 = tmp7.to(tl.float32)
329
+ tmp10 = 1e-06
330
+ tmp11 = triton_helpers.maximum(tmp9, tmp10)
331
+ tmp12 = (tmp8 / tmp11)
332
+ tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None)
333
+ ''', device_str='cuda')
334
+
335
+
336
+ async_compile.wait(globals())
337
+ del async_compile
338
+
339
+ class Runner:
340
+ def __init__(self, partitions):
341
+ self.partitions = partitions
342
+
343
+ def recursively_apply_fns(self, fns):
344
+ new_callables = []
345
+ for fn, c in zip(fns, self.partitions):
346
+ new_callables.append(fn(c))
347
+ self.partitions = new_callables
348
+
349
+ def call(self, args):
350
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args
351
+ args.clear()
352
+ s3 = arg0_1
353
+ s71 = arg2_1
354
+ s14 = arg5_1
355
+ assert_size_stride(arg1_1, (8, s3, 32000), (32000*s3, 32000, 1))
356
+ assert_size_stride(arg3_1, (8, s3, 32000), (s71, 32000, 1))
357
+ assert_size_stride(arg4_1, (8, s3, 1), (s3, 1, 1))
358
+ assert_size_stride(arg6_1, (8, s14, 1), (s14, 1, 1))
359
+ with torch.cuda._DeviceGuard(0):
360
+ torch.cuda.set_device(0)
361
+ buf0 = empty_strided_cuda((8, s3), (s3, 1), torch.int64)
362
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
363
+ triton_red_fused_argmax_0_xnumel = 8*s3
364
+ stream0 = get_raw_stream(0)
365
+ triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream0)
366
+ del arg1_1
367
+ buf1 = empty_strided_cuda((8, s3), (s3, 1), torch.int64)
368
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
369
+ triton_red_fused_argmax_1_xnumel = 8*s3
370
+ stream0 = get_raw_stream(0)
371
+ triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream0)
372
+ del arg3_1
373
+ buf3 = empty_strided_cuda((2, ), (1, ), torch.int64)
374
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
375
+ triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 4*s3
376
+ stream0 = get_raw_stream(0)
377
+ triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf3, s3, 2, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream0)
378
+ del arg4_1
379
+ del buf0
380
+ del buf1
381
+ buf5 = empty_strided_cuda((2, ), (1, ), torch.int64)
382
+ # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum]
383
+ triton_red_fused_sum_3_r0_numel = 4*s14
384
+ stream0 = get_raw_stream(0)
385
+ triton_red_fused_sum_3.run(arg6_1, buf5, s14, 2, triton_red_fused_sum_3_r0_numel, stream=stream0)
386
+ del arg6_1
387
+ buf7 = empty_strided_cuda((), (), torch.float32)
388
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
389
+ stream0 = get_raw_stream(0)
390
+ triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream0)
391
+ del buf3
392
+ del buf5
393
+ return (buf7, )
394
+
395
+ runner = Runner(partitions=[])
396
+ call = runner.call
397
+ recursively_apply_fns = runner.recursively_apply_fns
398
+
399
+
400
+ def benchmark_compiled_module(times=10, repeat=10):
401
+ from torch._dynamo.testing import rand_strided
402
+ from torch._inductor.utils import print_performance
403
+ arg0_1 = 2009
404
+ arg1_1 = rand_strided((8, 2009, 32000), (64288000, 32000, 1), device='cuda:0', dtype=torch.bfloat16)
405
+ arg2_1 = 64512000
406
+ arg3_1 = rand_strided((8, 2009, 32000), (64512000, 32000, 1), device='cuda:0', dtype=torch.float32)
407
+ arg4_1 = rand_strided((8, 2009, 1), (2009, 1, 1), device='cuda:0', dtype=torch.int64)
408
+ arg5_1 = 2009
409
+ arg6_1 = rand_strided((8, 2009, 1), (2009, 1, 1), device='cuda:0', dtype=torch.int64)
410
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1])
411
+ return print_performance(fn, times=times, repeat=repeat)
412
+
413
+
414
+ if __name__ == "__main__":
415
+ from torch._inductor.wrapper_benchmark import compiled_module_main
416
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/el/celb4xosanyf3m2sx6v3t54w4bgkx65m4lb2newu7nggikw6jbxj.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 131072, 'r0_': 128},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x1 = xindex // ks0
27
+ x0 = (xindex % ks0)
28
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
29
+ x3 = xindex
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_2 = r0_index
36
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
37
+ tmp1 = ks1*ks2
38
+ tmp2 = tmp0 < tmp1
39
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
40
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp5 = tmp4.to(tl.float32)
42
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
43
+ tmp7 = tmp5 * tmp6
44
+ tmp8 = tmp7.to(tl.float32)
45
+ tmp9 = tmp3 * tmp8
46
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
47
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
48
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
49
+ tmp14 = _tmp13 + tmp12
50
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
51
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
52
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
SpecForge-ext/cache/compiled_kernels/et/cet6lrlwcthdi3by3ttnab2z245l4q55x7tvdilkic6xqjfjlixg.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 1, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 1
20
+ r0_numel = 4096
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
29
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
30
+ r0_index = r0_offset + r0_base
31
+ r0_mask = r0_index < r0_numel
32
+ roffset = r0_offset
33
+ rindex = r0_index
34
+ r0_0 = r0_index
35
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
36
+ tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
37
+ tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
38
+ tmp2 = tmp0 == tmp1
39
+ tmp3 = tmp2.to(tl.int64)
40
+ tmp5 = tmp3 * tmp4
41
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
42
+ tmp8 = _tmp7 + tmp6
43
+ _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
44
+ tmp7 = tl.sum(_tmp7, 1)[:, None]
45
+ _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
46
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
47
+ r0_index = r0_offset + r0_base
48
+ r0_mask = r0_index < r0_numel
49
+ roffset = r0_offset
50
+ rindex = r0_index
51
+ r0_0 = r0_index
52
+ tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
53
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
54
+ tmp12 = _tmp11 + tmp10
55
+ _tmp11 = tl.where(r0_mask, tmp12, _tmp11)
56
+ tmp11 = tl.sum(_tmp11, 1)[:, None]
57
+ tmp13 = tmp7.to(tl.float32)
58
+ tmp14 = tmp11.to(tl.float32)
59
+ tmp15 = 1e-06
60
+ tmp16 = triton_helpers.maximum(tmp14, tmp15)
61
+ tmp17 = (tmp13 / tmp16)
62
+ tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None)
SpecForge-ext/cache/compiled_kernels/ey/cey3ar6s7f2t62buescu5cctxdhf6hmbv3ps5d3tmh235oaj3fj6.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4194304},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/ey/ceyifglcwq5k7zog6faauufd7zk5fsacgjqk43m6vpya73dy3l62.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['5_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wr/cwrt3cdfiri2z4jso4afypedtru4cdebpo556yzgrqawlufswk26.py
38
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
39
+ # Source node to ATen node mapping:
40
+ # and_2 => bitwise_and_1
41
+ # and_3 => bitwise_and_2
42
+ # and_4 => bitwise_and_3, view_8
43
+ # b => iota
44
+ # batched_outputs_2 => view_9
45
+ # causal_mask => ge, view
46
+ # diagnol_mask => eq
47
+ # index => index
48
+ # index_1 => index_1
49
+ # index_2 => index_2
50
+ # lt => lt, view_1
51
+ # lt_1 => lt_1, view_2
52
+ # m => iota_2
53
+ # mask_2 => view_10
54
+ # mask_3 => permute
55
+ # mask_block_sum => sum_1
56
+ # n => iota_3
57
+ # padding_mask => bitwise_and, view_3, view_4
58
+ # padding_mask_1 => lt_2, view_6
59
+ # remainder => remainder
60
+ # remainder_1 => remainder_1
61
+ # result_1 => bitwise_or, full_default
62
+ # result_2 => bitwise_or_1
63
+ # sub => sub, view_7
64
+ # suffix_mask => ge_1
65
+ # Graph fragment:
66
+ # %arg0_1 : Tensor "i64[8][1]cuda:6" = PlaceHolder[target=arg0_1]
67
+ # %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:6, pin_memory: False})
68
+ # %iota_2 : Tensor "i64[2048][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
69
+ # %view : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
70
+ # %iota_3 : Tensor "i64[2048][1]cuda:6"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
71
+ # %ge : Tensor "b8[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {})
72
+ # %iota : Tensor "i64[8][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
73
+ # %index : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
74
+ # %view_1 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {})
75
+ # %lt : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {})
76
+ # %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {})
77
+ # %index_1 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
78
+ # %view_2 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {})
79
+ # %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {})
80
+ # %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {})
81
+ # %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {})
82
+ # %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {})
83
+ # %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {})
84
+ # %ge_1 : Tensor "b8[2048][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {})
85
+ # %remainder : Tensor "i64[2048][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {})
86
+ # %index_2 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
87
+ # %view_6 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {})
88
+ # %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {})
89
+ # %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {})
90
+ # %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {})
91
+ # %view_7 : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
92
+ # %sub : Tensor "i64[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {})
93
+ # %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {})
94
+ # %eq : Tensor "b8[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {})
95
+ # %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {})
96
+ # %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {})
97
+ # %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {})
98
+ # %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {})
99
+ # %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {})
100
+ # %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {})
101
+ # return %sum_1
102
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', '''
103
+ import triton
104
+ import triton.language as tl
105
+
106
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
107
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
108
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
109
+ triton_helpers.set_driver_to_gpu()
110
+
111
+ @triton_heuristics.reduction(
112
+ size_hints={'x': 2048, 'r0_': 16384},
113
+ reduction_hint=ReductionHint.INNER,
114
+ filename=__file__,
115
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
116
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}}
117
+ )
118
+ @triton.jit
119
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
120
+ xnumel = 2048
121
+ r0_numel = 16384
122
+ rnumel = r0_numel
123
+ RBLOCK: tl.constexpr = R0_BLOCK
124
+ xoffset = tl.program_id(0) * XBLOCK
125
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
126
+ xmask = xindex < xnumel
127
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
128
+ rbase = r0_base
129
+ x1 = ((xindex // 16) % 16)
130
+ x0 = (xindex % 16)
131
+ x2 = xindex // 256
132
+ tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
133
+ _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
134
+ x6 = xindex
135
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
136
+ r0_index = r0_offset + r0_base
137
+ r0_mask = r0_index < r0_numel
138
+ roffset = r0_offset
139
+ rindex = r0_index
140
+ r0_4 = r0_index // 128
141
+ r0_3 = (r0_index % 128)
142
+ tmp0 = r0_4 + 128*x1
143
+ tmp1 = r0_3 + 128*x0
144
+ tmp2 = tmp0 >= tmp1
145
+ tmp4 = tmp1 < tmp3
146
+ tmp5 = tmp0 < tmp3
147
+ tmp6 = tmp4 & tmp5
148
+ tmp7 = tmp2 & tmp6
149
+ tmp8 = tl.full([1, 1], False, tl.int1)
150
+ tmp9 = tmp8 | tmp7
151
+ tmp10 = tl.full([1, 1], 2048, tl.int64)
152
+ tmp11 = tmp1 >= tmp10
153
+ tmp12 = tmp11 & tmp4
154
+ tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
155
+ tmp14 = (tmp13 % tmp10)
156
+ tmp15 = tl.full([1, 1], 0, tl.int32)
157
+ tmp16 = tmp14 != tmp15
158
+ tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
159
+ tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
160
+ tmp19 = tmp17 != tmp18
161
+ tmp20 = tmp16 & tmp19
162
+ tmp21 = tmp14 + tmp10
163
+ tmp22 = tl.where(tmp20, tmp21, tmp14)
164
+ tmp23 = tl.full([1, 1], 0, tl.int64)
165
+ tmp24 = tmp22 == tmp23
166
+ tmp25 = tmp12 & tmp24
167
+ tmp26 = tmp9 | tmp25
168
+ tmp27 = tmp26.to(tl.int64)
169
+ tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
170
+ tmp30 = _tmp29 + tmp28
171
+ _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
172
+ tmp29 = tl.sum(_tmp29, 1)[:, None]
173
+ tl.store(out_ptr0 + (x6), tmp29, xmask)
174
+ ''', device_str='cuda')
175
+
176
+
177
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hw/chwm44jdqtovypwqknevqvz2d2xrazceb4ci2erooz4tahlocvzv.py
178
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
179
+ # Source node to ATen node mapping:
180
+ # dense_mask_4 => full_default_4
181
+ # Graph fragment:
182
+ # %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
183
+ # return %index_put_1
184
+ triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', '''
185
+ import triton
186
+ import triton.language as tl
187
+
188
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
189
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
190
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
191
+ triton_helpers.set_driver_to_gpu()
192
+
193
+ @triton_heuristics.pointwise(
194
+ size_hints={'x': 4096},
195
+ filename=__file__,
196
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
197
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}},
198
+ min_elem_per_thread=0
199
+ )
200
+ @triton.jit
201
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
202
+ xnumel = 2176
203
+ xoffset = tl.program_id(0) * XBLOCK
204
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
205
+ xmask = xindex < xnumel
206
+ x0 = xindex
207
+ tmp0 = tl.full([1], 0, tl.int32)
208
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
209
+ ''', device_str='cuda')
210
+
211
+
212
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rk/crkxpwiwhzkvun7i5d2pegofthyfijn5wygwnaev3twwlrbuojqe.py
213
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
214
+ # Source node to ATen node mapping:
215
+ # arange_4 => iota_4
216
+ # arange_6 => iota_8
217
+ # child_3 => convert_element_type_3
218
+ # child_4 => convert_element_type_4
219
+ # child_7 => convert_element_type_6
220
+ # child_8 => convert_element_type_7
221
+ # col_indices => sort
222
+ # col_indices_1 => sort_1
223
+ # col_range => iota_5
224
+ # col_range_1 => iota_9
225
+ # dense_mask => convert_element_type_2
226
+ # dense_mask_1 => convert_element_type_5
227
+ # dense_mask_2 => full_default_1
228
+ # dense_mask_4 => full_default_4
229
+ # full_blocks => eq_1
230
+ # full_blocks_1 => convert_element_type_1
231
+ # gt => gt
232
+ # index_mask => lt_4
233
+ # index_mask_1 => lt_5
234
+ # lt_3 => lt_3
235
+ # num_blocks_in_row => sum_2
236
+ # num_blocks_in_row_1 => sum_3
237
+ # partial_blocks => bitwise_and_4
238
+ # partial_blocks_1 => convert_element_type
239
+ # row_indices => unsqueeze
240
+ # row_indices_1 => unsqueeze_7
241
+ # setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6
242
+ # setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9
243
+ # unsqueeze_1 => unsqueeze_1
244
+ # unsqueeze_3 => unsqueeze_8
245
+ # valid_indices => full_default_2, where
246
+ # valid_indices_1 => full_default_5, where_1
247
+ # Graph fragment:
248
+ # %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=sum_1]
249
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_2]
250
+ # %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_3]
251
+ # %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=buf2]
252
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_3]
253
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_4]
254
+ # %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6" = PlaceHolder[target=index_put]
255
+ # %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=buf4]
256
+ # %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_6]
257
+ # %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_7]
258
+ # %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6" = PlaceHolder[target=index_put_1]
259
+ # %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
260
+ # %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {})
261
+ # %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {})
262
+ # %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {})
263
+ # %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {})
264
+ # %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True})
265
+ # %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {})
266
+ # %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {})
267
+ # %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {})
268
+ # %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True})
269
+ # %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
270
+ # %iota_7 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
271
+ # %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {})
272
+ # %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {})
273
+ # %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {})
274
+ # %iota_6 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
275
+ # %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {})
276
+ # %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {})
277
+ # %iota_4 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False})
278
+ # %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {})
279
+ # %iota_5 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False})
280
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {})
281
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {})
282
+ # %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {})
283
+ # %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {})
284
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {})
285
+ # %full_default_2 : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
286
+ # %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {})
287
+ # %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
288
+ # %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {})
289
+ # %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
290
+ # %iota_11 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
291
+ # %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {})
292
+ # %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {})
293
+ # %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {})
294
+ # %iota_10 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False})
295
+ # %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {})
296
+ # %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {})
297
+ # %iota_8 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False})
298
+ # %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {})
299
+ # %iota_9 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False})
300
+ # %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {})
301
+ # %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {})
302
+ # %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {})
303
+ # %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {})
304
+ # %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {})
305
+ # %full_default_5 : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
306
+ # %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {})
307
+ # %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False})
308
+ # %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {})
309
+ # return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16
310
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', '''
311
+ import triton
312
+ import triton.language as tl
313
+
314
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
315
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
316
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
317
+ triton_helpers.set_driver_to_gpu()
318
+
319
+ @triton_heuristics.persistent_reduction(
320
+ size_hints={'x': 128, 'r0_': 16},
321
+ reduction_hint=ReductionHint.DEFAULT,
322
+ filename=__file__,
323
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
324
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
325
+ )
326
+ @triton.jit
327
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
328
+ xnumel = 128
329
+ r0_numel = 16
330
+ R0_BLOCK: tl.constexpr = 16
331
+ rnumel = r0_numel
332
+ RBLOCK: tl.constexpr = R0_BLOCK
333
+ xoffset = tl.program_id(0) * XBLOCK
334
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
335
+ xmask = xindex < xnumel
336
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
337
+ r0_offset = 0
338
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
339
+ roffset = r0_offset
340
+ rindex = r0_index
341
+ r0_1 = r0_index
342
+ x0 = xindex
343
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
344
+ tmp1 = tl.full([1, 1], 0, tl.int64)
345
+ tmp2 = tmp0 > tmp1
346
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
347
+ tmp4 = tmp0 < tmp3
348
+ tmp5 = tmp2 & tmp4
349
+ tmp6 = tmp5.to(tl.int8)
350
+ tmp7 = tmp6.to(tl.int32)
351
+ tmp8 = r0_1
352
+ tmp9 = tmp8.to(tl.int16)
353
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
354
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
355
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
356
+ tmp14 = tmp0 == tmp3
357
+ tmp15 = tmp14.to(tl.int8)
358
+ tmp16 = tmp15.to(tl.int32)
359
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
360
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
361
+ tmp20 = tmp7.to(tl.int64)
362
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
363
+ tmp23 = tl.where(xmask, tmp21, 0)
364
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
365
+ tmp25 = tmp16.to(tl.int64)
366
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
367
+ tmp28 = tl.where(xmask, tmp26, 0)
368
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
369
+ tmp30 = tmp24.to(tl.int32)
370
+ tmp31 = tmp29.to(tl.int32)
371
+ tmp32 = tmp13.to(tl.int64)
372
+ tmp33 = tmp32.to(tl.int32)
373
+ tmp34 = tmp8 < tmp30
374
+ tmp35 = tl.full([1, 1], 16, tl.int32)
375
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
376
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
377
+ tmp38 = tmp36 + tmp37
378
+ tmp39 = tmp36 < 0
379
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
380
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
381
+ tmp42 = tl.full([1, 1], 1, tl.int32)
382
+ tmp43 = tmp19.to(tl.int64)
383
+ tmp44 = tmp43.to(tl.int32)
384
+ tmp45 = tmp8 < tmp31
385
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
386
+ tmp47 = tmp46 + tmp37
387
+ tmp48 = tmp46 < 0
388
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
389
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
390
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
391
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
392
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
393
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
394
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
395
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
396
+ ''', device_str='cuda')
397
+
398
+
399
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4y/c4yxgotihoxpn6o5xa4jvkcy7shlgnyv44u6dpm5e746f6dwg7oe.py
400
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
401
+ # Source node to ATen node mapping:
402
+ # batched_outputs_3 => clone_4, slice_2
403
+ # col_indices_2 => sort_2
404
+ # num_blocks_in_row_2 => sum_4
405
+ # q_indices => clone_6, convert_element_type_9
406
+ # q_num_blocks => convert_element_type_8
407
+ # transpose => permute_1
408
+ # Graph fragment:
409
+ # %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6" = PlaceHolder[target=buf9]
410
+ # %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=buf11]
411
+ # %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_4]
412
+ # %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {})
413
+ # %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
414
+ # %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
415
+ # %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True})
416
+ # %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
417
+ # %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
418
+ # %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
419
+ # %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
420
+ # return %buf11,%sum_4,%clone_6,%convert_element_type_8
421
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', '''
422
+ import triton
423
+ import triton.language as tl
424
+
425
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
426
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
427
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
428
+ triton_helpers.set_driver_to_gpu()
429
+
430
+ @triton_heuristics.persistent_reduction(
431
+ size_hints={'x': 128, 'r0_': 16},
432
+ reduction_hint=ReductionHint.DEFAULT,
433
+ filename=__file__,
434
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
435
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}}
436
+ )
437
+ @triton.jit
438
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
439
+ xnumel = 128
440
+ r0_numel = 16
441
+ R0_BLOCK: tl.constexpr = 16
442
+ rnumel = r0_numel
443
+ RBLOCK: tl.constexpr = R0_BLOCK
444
+ xoffset = tl.program_id(0) * XBLOCK
445
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
446
+ xmask = xindex < xnumel
447
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
448
+ r0_offset = 0
449
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
450
+ roffset = r0_offset
451
+ rindex = r0_index
452
+ r0_2 = r0_index
453
+ x0 = (xindex % 16)
454
+ x1 = xindex // 16
455
+ x3 = xindex
456
+ tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
457
+ tmp1 = r0_2
458
+ tmp2 = tmp1.to(tl.int16)
459
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
460
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
461
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
462
+ tmp7 = tmp0.to(tl.int64)
463
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
464
+ tmp10 = tl.where(xmask, tmp8, 0)
465
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
466
+ tmp12 = tmp6.to(tl.int64)
467
+ tmp13 = tmp12.to(tl.int32)
468
+ tmp14 = tmp11.to(tl.int32)
469
+ tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
470
+ tl.store(out_ptr3 + (x3), tmp14, xmask)
471
+ ''', device_str='cuda')
472
+
473
+
474
+ async_compile.wait(globals())
475
+ del async_compile
476
+
477
+ class Runner:
478
+ def __init__(self, partitions):
479
+ self.partitions = partitions
480
+
481
+ def recursively_apply_fns(self, fns):
482
+ new_callables = []
483
+ for fn, c in zip(fns, self.partitions):
484
+ new_callables.append(fn(c))
485
+ self.partitions = new_callables
486
+
487
+ def call(self, args):
488
+ arg0_1, = args
489
+ args.clear()
490
+ assert_size_stride(arg0_1, (8, ), (1, ))
491
+ with torch.cuda._DeviceGuard(6):
492
+ torch.cuda.set_device(6)
493
+ buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64)
494
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
495
+ stream6 = get_raw_stream(6)
496
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream6)
497
+ del arg0_1
498
+ buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32)
499
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
500
+ stream6 = get_raw_stream(6)
501
+ triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream6)
502
+ buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32)
503
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
504
+ stream6 = get_raw_stream(6)
505
+ triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream6)
506
+ buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
507
+ buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
508
+ buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
509
+ buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
510
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
511
+ stream6 = get_raw_stream(6)
512
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream6)
513
+ del buf0
514
+ buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
515
+ buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
516
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
517
+ stream6 = get_raw_stream(6)
518
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream6)
519
+ del buf8
520
+ buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
521
+ buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
522
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
523
+ stream6 = get_raw_stream(6)
524
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream6)
525
+ del buf15
526
+ return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, )
527
+
528
+ runner = Runner(partitions=[])
529
+ call = runner.call
530
+ recursively_apply_fns = runner.recursively_apply_fns
531
+
532
+
533
+ def benchmark_compiled_module(times=10, repeat=10):
534
+ from torch._dynamo.testing import rand_strided
535
+ from torch._inductor.utils import print_performance
536
+ arg0_1 = rand_strided((8, ), (1, ), device='cuda:6', dtype=torch.int64)
537
+ fn = lambda: call([arg0_1])
538
+ return print_performance(fn, times=times, repeat=repeat)
539
+
540
+
541
+ if __name__ == "__main__":
542
+ from torch._inductor.wrapper_benchmark import compiled_module_main
543
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/ey/ceyzf3pcewvjtqjk6jiokovxh2sqktcak7dttp7wu3pugjxaoweu.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = ks0
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/ey/f8e6f482f3185b2937177b6d0b6caa60104c3cdb0966b9b98cfda24132197a8c.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"}
SpecForge-ext/cache/compiled_kernels/fg/cfg7sytfzjcof3mvqa6lexwoxlaj3zogf2jn2jbgerew6ytuhqkm.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
34
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
35
+ tmp3 = tl.where(xmask, tmp1, 0)
36
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
37
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
SpecForge-ext/cache/compiled_kernels/fg/cfgdj37atk5pvqz7oags4dv3jc65exjssmxxu3c4srgtfjnh7kgw.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 2
93
+ KV_LEN = ks0
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 16*ks1
131
+ stride_kv_idx_m = ks1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/fg/cfgilsqr4dj7cpcripi7zlobhu3rqxlfddiwwrzuy5xlumnjw5lh.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
34
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
35
+ tmp3 = tl.where(xmask, tmp1, 0)
36
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
37
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
SpecForge-ext/cache/compiled_kernels/fo/cfooe7ht55q5jhejzd3zyb3g5v64cvxjohkxeadllgnjxgiwo52v.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 2048},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x1 = xindex // ks0
24
+ x2 = xindex
25
+ tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last')
26
+ tl.store(out_ptr0 + (x2), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/ft/cftsee2mvtzxgy2wgchwunv4g4rgysco4n3gsokqlal6zoqbmnub.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['7_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py
38
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
39
+ # Source node to ATen node mapping:
40
+ # argmax => argmax
41
+ # Graph fragment:
42
+ # %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:1" = PlaceHolder[target=arg0_1]
43
+ # %argmax : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {})
44
+ # return %argmax
45
+ triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', '''
46
+ import triton
47
+ import triton.language as tl
48
+
49
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
50
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
51
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
52
+ triton_helpers.set_driver_to_gpu()
53
+
54
+ @triton_heuristics.reduction(
55
+ size_hints={'x': 4096, 'r0_': 32768},
56
+ reduction_hint=ReductionHint.INNER,
57
+ filename=__file__,
58
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
59
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}}
60
+ )
61
+ @triton.jit
62
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
63
+ xnumel = 4096
64
+ r0_numel = 32000
65
+ rnumel = r0_numel
66
+ RBLOCK: tl.constexpr = R0_BLOCK
67
+ xoffset = tl.program_id(0) * XBLOCK
68
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
69
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
70
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
71
+ rbase = r0_base
72
+ x0 = xindex
73
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
74
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
75
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
76
+ r0_index = r0_offset + r0_base
77
+ r0_mask = r0_index < r0_numel
78
+ roffset = r0_offset
79
+ rindex = r0_index
80
+ r0_1 = r0_index
81
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
82
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
83
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
84
+ _tmp2, _tmp2_index, tmp1, rindex
85
+ )
86
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
87
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
88
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
89
+ tmp2 = tmp2_idx[:, None]
90
+ tl.store(out_ptr0 + (x0), tmp2, None)
91
+ ''', device_str='cuda')
92
+
93
+
94
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gk/cgkqxvbgd6bawj2pp2icrhzkfuzcptxodfjpshgozv6ysjvxo65g.py
95
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
96
+ # Source node to ATen node mapping:
97
+ # argmax_1 => argmax_1
98
+ # Graph fragment:
99
+ # %arg1_1 : Tensor "f32[2, 2048, 32000][65760000, 32000, 1]cuda:1" = PlaceHolder[target=arg1_1]
100
+ # %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
101
+ # return %argmax_1
102
+ triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', '''
103
+ import triton
104
+ import triton.language as tl
105
+
106
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
107
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
108
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
109
+ triton_helpers.set_driver_to_gpu()
110
+
111
+ @triton_heuristics.reduction(
112
+ size_hints={'x': 4096, 'r0_': 32768},
113
+ reduction_hint=ReductionHint.INNER,
114
+ filename=__file__,
115
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
116
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}}
117
+ )
118
+ @triton.jit
119
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
120
+ xnumel = 4096
121
+ r0_numel = 32000
122
+ rnumel = r0_numel
123
+ RBLOCK: tl.constexpr = R0_BLOCK
124
+ xoffset = tl.program_id(0) * XBLOCK
125
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
126
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
127
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
128
+ rbase = r0_base
129
+ x0 = (xindex % 2048)
130
+ x1 = xindex // 2048
131
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
132
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
133
+ x3 = xindex
134
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
135
+ r0_index = r0_offset + r0_base
136
+ r0_mask = r0_index < r0_numel
137
+ roffset = r0_offset
138
+ rindex = r0_index
139
+ r0_2 = r0_index
140
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0)
141
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
142
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
143
+ _tmp2, _tmp2_index, tmp1, rindex
144
+ )
145
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
146
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
147
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
148
+ tmp2 = tmp2_idx[:, None]
149
+ tl.store(out_ptr0 + (x3), tmp2, None)
150
+ ''', device_str='cuda')
151
+
152
+
153
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py
154
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
155
+ # Source node to ATen node mapping:
156
+ # clamp_min => clamp_min
157
+ # eq => eq
158
+ # mul => mul
159
+ # squeeze => squeeze
160
+ # sum_1 => sum_1
161
+ # sum_2 => sum_2
162
+ # truediv => div
163
+ # Graph fragment:
164
+ # %argmax : Tensor "i64[2, 2048][2048, 1]cuda:1" = PlaceHolder[target=argmax]
165
+ # %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:1" = PlaceHolder[target=argmax_1]
166
+ # %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:1" = PlaceHolder[target=arg2_1]
167
+ # %arg3_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:1" = PlaceHolder[target=arg3_1]
168
+ # %sum_1 : Tensor "i64[][]cuda:1" = PlaceHolder[target=sum_1]
169
+ # %sum_2 : Tensor "i64[][]cuda:1" = PlaceHolder[target=sum_2]
170
+ # %eq : Tensor "b8[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {})
171
+ # %squeeze : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {})
172
+ # %mul : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {})
173
+ # %sum_1 : Tensor "i64[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {})
174
+ # %sum_2 : Tensor "i64[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {})
175
+ # %clamp_min : Tensor "f32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {})
176
+ # %div : Tensor "f32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {})
177
+ # return %sum_1,%sum_2,%div
178
+ triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', '''
179
+ import triton
180
+ import triton.language as tl
181
+
182
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
183
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
184
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
185
+ triton_helpers.set_driver_to_gpu()
186
+
187
+ @triton_heuristics.reduction(
188
+ size_hints={'x': 1, 'r0_': 4096},
189
+ reduction_hint=ReductionHint.INNER,
190
+ filename=__file__,
191
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
192
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}}
193
+ )
194
+ @triton.jit
195
+ def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
196
+ xnumel = 1
197
+ r0_numel = 4096
198
+ rnumel = r0_numel
199
+ RBLOCK: tl.constexpr = R0_BLOCK
200
+ xoffset = tl.program_id(0) * XBLOCK
201
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
202
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
203
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
204
+ rbase = r0_base
205
+ _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
206
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
207
+ r0_index = r0_offset + r0_base
208
+ r0_mask = r0_index < r0_numel
209
+ roffset = r0_offset
210
+ rindex = r0_index
211
+ r0_0 = r0_index
212
+ tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
213
+ tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
214
+ tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
215
+ tmp2 = tmp0 == tmp1
216
+ tmp3 = tmp2.to(tl.int64)
217
+ tmp5 = tmp3 * tmp4
218
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
219
+ tmp8 = _tmp7 + tmp6
220
+ _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
221
+ tmp7 = tl.sum(_tmp7, 1)[:, None]
222
+ _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
223
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
224
+ r0_index = r0_offset + r0_base
225
+ r0_mask = r0_index < r0_numel
226
+ roffset = r0_offset
227
+ rindex = r0_index
228
+ r0_0 = r0_index
229
+ tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
230
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
231
+ tmp12 = _tmp11 + tmp10
232
+ _tmp11 = tl.where(r0_mask, tmp12, _tmp11)
233
+ tmp11 = tl.sum(_tmp11, 1)[:, None]
234
+ tmp13 = tmp7.to(tl.float32)
235
+ tmp14 = tmp11.to(tl.float32)
236
+ tmp15 = 1e-06
237
+ tmp16 = triton_helpers.maximum(tmp14, tmp15)
238
+ tmp17 = (tmp13 / tmp16)
239
+ tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None)
240
+ ''', device_str='cuda')
241
+
242
+
243
+ async_compile.wait(globals())
244
+ del async_compile
245
+
246
+ class Runner:
247
+ def __init__(self, partitions):
248
+ self.partitions = partitions
249
+
250
+ def recursively_apply_fns(self, fns):
251
+ new_callables = []
252
+ for fn, c in zip(fns, self.partitions):
253
+ new_callables.append(fn(c))
254
+ self.partitions = new_callables
255
+
256
+ def call(self, args):
257
+ arg0_1, arg1_1, arg2_1, arg3_1 = args
258
+ args.clear()
259
+ assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1))
260
+ assert_size_stride(arg1_1, (2, 2048, 32000), (65760000, 32000, 1))
261
+ assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1))
262
+ assert_size_stride(arg3_1, (2, 2048, 1), (2048, 1, 1))
263
+ with torch.cuda._DeviceGuard(1):
264
+ torch.cuda.set_device(1)
265
+ buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64)
266
+ # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
267
+ stream1 = get_raw_stream(1)
268
+ triton_red_fused_argmax_0.run(arg0_1, buf0, 4096, 32000, stream=stream1)
269
+ del arg0_1
270
+ buf1 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64)
271
+ # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
272
+ stream1 = get_raw_stream(1)
273
+ triton_red_fused_argmax_1.run(arg1_1, buf1, 4096, 32000, stream=stream1)
274
+ del arg1_1
275
+ buf4 = empty_strided_cuda((), (), torch.float32)
276
+ # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
277
+ stream1 = get_raw_stream(1)
278
+ triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, arg3_1, buf4, 1, 4096, stream=stream1)
279
+ del arg2_1
280
+ del arg3_1
281
+ del buf0
282
+ del buf1
283
+ return (buf4, )
284
+
285
+ runner = Runner(partitions=[])
286
+ call = runner.call
287
+ recursively_apply_fns = runner.recursively_apply_fns
288
+
289
+
290
+ def benchmark_compiled_module(times=10, repeat=10):
291
+ from torch._dynamo.testing import rand_strided
292
+ from torch._inductor.utils import print_performance
293
+ arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:1', dtype=torch.bfloat16)
294
+ arg1_1 = rand_strided((2, 2048, 32000), (65760000, 32000, 1), device='cuda:1', dtype=torch.float32)
295
+ arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.int64)
296
+ arg3_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.int64)
297
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
298
+ return print_performance(fn, times=times, repeat=repeat)
299
+
300
+
301
+ if __name__ == "__main__":
302
+ from torch._inductor.wrapper_benchmark import compiled_module_main
303
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/g3/cg3kczutozttzr55b4vjq62nto7vv2qnqb553mhae4gtgepz7vkj.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 16384},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 16384
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = ((xindex // ks0) % 16)
29
+ x2 = xindex // ks2
30
+ _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
31
+ x5 = xindex
32
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_3 = (r0_index % 128)
38
+ r0_4 = r0_index // 128
39
+ tmp0 = r0_3 + 128*x0
40
+ tmp1 = ks1
41
+ tmp2 = tmp0 < tmp1
42
+ tmp3 = r0_4 + 128*x1
43
+ tmp4 = r0_3 + 128*x0
44
+ tmp5 = tmp3 >= tmp4
45
+ tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
46
+ tmp7 = tmp4 < tmp6
47
+ tmp8 = tmp3 < tmp6
48
+ tmp9 = tmp7 & tmp8
49
+ tmp10 = tmp5 & tmp9
50
+ tmp11 = tl.full([1, 1], False, tl.int1)
51
+ tmp12 = tmp11 | tmp10
52
+ tmp13 = tl.full([1, 1], 2048, tl.int64)
53
+ tmp14 = tmp4 >= tmp13
54
+ tmp15 = ((r0_3 + 128*x0) % 2048)
55
+ tmp16 = tmp15 < tmp6
56
+ tmp17 = tmp14 & tmp16
57
+ tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
58
+ tmp19 = (tmp18 % tmp13)
59
+ tmp20 = tl.full([1, 1], 0, tl.int32)
60
+ tmp21 = tmp19 != tmp20
61
+ tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
62
+ tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
63
+ tmp24 = tmp22 != tmp23
64
+ tmp25 = tmp21 & tmp24
65
+ tmp26 = tmp19 + tmp13
66
+ tmp27 = tl.where(tmp25, tmp26, tmp19)
67
+ tmp28 = tl.full([1, 1], 0, tl.int64)
68
+ tmp29 = tmp27 == tmp28
69
+ tmp30 = tmp17 & tmp29
70
+ tmp31 = tmp12 | tmp30
71
+ tmp32 = tl.full(tmp31.shape, False, tmp31.dtype)
72
+ tmp33 = tl.where(tmp2, tmp31, tmp32)
73
+ tmp34 = tmp33.to(tl.int64)
74
+ tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK])
75
+ tmp37 = _tmp36 + tmp35
76
+ _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36)
77
+ tmp36 = tl.sum(_tmp36, 1)[:, None]
78
+ tmp38 = tl.full([1, 1], 0, tl.int64)
79
+ tmp39 = tmp36 > tmp38
80
+ tmp40 = tl.full([1, 1], 16384, tl.int64)
81
+ tmp41 = tmp36 < tmp40
82
+ tmp42 = tmp39 & tmp41
83
+ tmp43 = tmp42.to(tl.int8)
84
+ tmp44 = tmp43.to(tl.int32)
85
+ tmp45 = tmp36 == tmp40
86
+ tmp46 = tmp45.to(tl.int8)
87
+ tmp47 = tmp46.to(tl.int32)
88
+ tl.store(out_ptr1 + (x5), tmp44, xmask)
89
+ tl.store(out_ptr2 + (x5), tmp47, xmask)
SpecForge-ext/cache/compiled_kernels/hq/chqc3is7lze3bdohf7qrowyfetyhjquhgfsobrnoq7hbrmp6ohdx.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['2_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py
38
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
39
+ # Source node to ATen node mapping:
40
+ # hidden_states => convert_element_type
41
+ # hidden_states_1 => mul_16
42
+ # to_1 => convert_element_type_1
43
+ # Graph fragment:
44
+ # %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=tangents_1]
45
+ # %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=primals_4]
46
+ # %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6" = PlaceHolder[target=rsqrt]
47
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
48
+ # %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
49
+ # %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
50
+ # %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
51
+ # %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {})
52
+ # return %buf0
53
+ triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+ triton_helpers.set_driver_to_gpu()
61
+
62
+ @triton_heuristics.reduction(
63
+ size_hints={'x': 131072, 'r0_': 128},
64
+ reduction_hint=ReductionHint.OUTER,
65
+ filename=__file__,
66
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
67
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
68
+ )
69
+ @triton.jit
70
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x1 = xindex // ks0
79
+ x0 = (xindex % ks0)
80
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
81
+ x3 = xindex
82
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
83
+ r0_index = r0_offset + r0_base
84
+ r0_mask = r0_index < r0_numel
85
+ roffset = r0_offset
86
+ rindex = r0_index
87
+ r0_2 = r0_index
88
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
89
+ tmp1 = ks1*ks2
90
+ tmp2 = tmp0 < tmp1
91
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
92
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
93
+ tmp5 = tmp4.to(tl.float32)
94
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
95
+ tmp7 = tmp5 * tmp6
96
+ tmp8 = tmp7.to(tl.float32)
97
+ tmp9 = tmp3 * tmp8
98
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
99
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
100
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
101
+ tmp14 = _tmp13 + tmp12
102
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
103
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
104
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
105
+ ''', device_str='cuda')
106
+
107
+
108
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/yf/cyft7sialepriw6eujulaxpi57qlrafkmp4k2kjwzw4noh23ddz6.py
109
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
110
+ # Source node to ATen node mapping:
111
+ # hidden_states => convert_element_type
112
+ # hidden_states_1 => mul_16
113
+ # to_1 => convert_element_type_1
114
+ # Graph fragment:
115
+ # %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:6" = PlaceHolder[target=buf0]
116
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
117
+ # %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
118
+ # %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
119
+ # %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
120
+ # %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {})
121
+ # return %sum_1
122
+ triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', '''
123
+ import triton
124
+ import triton.language as tl
125
+
126
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
127
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
128
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
129
+ triton_helpers.set_driver_to_gpu()
130
+
131
+ @triton_heuristics.persistent_reduction(
132
+ size_hints={'x': 4096, 'r0_': 32},
133
+ reduction_hint=ReductionHint.OUTER,
134
+ filename=__file__,
135
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
137
+ )
138
+ @triton.jit
139
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
140
+ r0_numel = 32
141
+ R0_BLOCK: tl.constexpr = 32
142
+ rnumel = r0_numel
143
+ RBLOCK: tl.constexpr = R0_BLOCK
144
+ xoffset = tl.program_id(0) * XBLOCK
145
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
146
+ xmask = xindex < xnumel
147
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
148
+ r0_offset = 0
149
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
150
+ roffset = r0_offset
151
+ rindex = r0_index
152
+ r0_1 = r0_index
153
+ x0 = xindex
154
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
155
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
156
+ tmp3 = tl.where(xmask, tmp1, 0)
157
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
158
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
159
+ ''', device_str='cuda')
160
+
161
+
162
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gn/cgnmjxikvi5ulcyj3uozif3le5hd26kw2kjhkcbhupqgudqi3bwn.py
163
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add]
164
+ # Source node to ATen node mapping:
165
+ # hidden_states => convert_element_type
166
+ # Graph fragment:
167
+ # %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=tangents_1]
168
+ # %primals_7 : Tensor "bf16[s33][1]cuda:6" = PlaceHolder[target=primals_7]
169
+ # %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=primals_4]
170
+ # %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6" = PlaceHolder[target=rsqrt]
171
+ # %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:6" = PlaceHolder[target=sum_2]
172
+ # %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {})
173
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
174
+ # %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {})
175
+ # %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {})
176
+ # %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {})
177
+ # %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), kwargs = {})
178
+ # %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {})
179
+ # %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {})
180
+ # %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {})
181
+ # %expand : Tensor "f32[s47, s87, s33][s87, 1, 0]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {})
182
+ # %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {})
183
+ # %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {})
184
+ # %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {})
185
+ # %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {})
186
+ # %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {})
187
+ # %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {})
188
+ # return %sum_2,%convert_element_type_3
189
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', '''
190
+ import triton
191
+ import triton.language as tl
192
+
193
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
194
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
195
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
196
+ triton_helpers.set_driver_to_gpu()
197
+
198
+ @triton_heuristics.reduction(
199
+ size_hints={'x': 4096, 'r0_': 4096},
200
+ reduction_hint=ReductionHint.INNER,
201
+ filename=__file__,
202
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
203
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
204
+ )
205
+ @triton.jit
206
+ def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
207
+ rnumel = r0_numel
208
+ RBLOCK: tl.constexpr = R0_BLOCK
209
+ xoffset = tl.program_id(0) * XBLOCK
210
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
211
+ xmask = xindex < xnumel
212
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
213
+ rbase = r0_base
214
+ x0 = xindex
215
+ _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
216
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
217
+ r0_index = r0_offset + r0_base
218
+ r0_mask = r0_index < r0_numel
219
+ roffset = r0_offset
220
+ rindex = r0_index
221
+ r0_1 = r0_index
222
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
223
+ tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
224
+ tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
225
+ tmp2 = tmp0 * tmp1
226
+ tmp3 = tmp2.to(tl.float32)
227
+ tmp5 = tmp4.to(tl.float32)
228
+ tmp6 = tmp3 * tmp5
229
+ tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
230
+ tmp9 = _tmp8 + tmp7
231
+ _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8)
232
+ tmp8 = tl.sum(_tmp8, 1)[:, None]
233
+ tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
234
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
235
+ r0_index = r0_offset + r0_base
236
+ r0_mask = r0_index < r0_numel
237
+ roffset = r0_offset
238
+ rindex = r0_index
239
+ r0_1 = r0_index
240
+ tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
241
+ tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
242
+ tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
243
+ tmp12 = tmp10 * tmp11
244
+ tmp13 = tmp12.to(tl.float32)
245
+ tmp15 = tmp13 * tmp14
246
+ tmp16 = -0.5
247
+ tmp17 = tmp8 * tmp16
248
+ tmp18 = tmp14 * tmp14
249
+ tmp19 = tmp18 * tmp14
250
+ tmp20 = tmp17 * tmp19
251
+ tmp21 = ks0
252
+ tmp22 = tmp21.to(tl.float32)
253
+ tmp23 = (tmp20 / tmp22)
254
+ tmp25 = tmp24.to(tl.float32)
255
+ tmp26 = 2.0
256
+ tmp27 = tmp25 * tmp26
257
+ tmp28 = tmp23 * tmp27
258
+ tmp29 = tmp15 + tmp28
259
+ tmp30 = tmp29.to(tl.float32)
260
+ tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask)
261
+ ''', device_str='cuda')
262
+
263
+
264
+ async_compile.wait(globals())
265
+ del async_compile
266
+
267
+ class Runner:
268
+ def __init__(self, partitions):
269
+ self.partitions = partitions
270
+
271
+ def recursively_apply_fns(self, fns):
272
+ new_callables = []
273
+ for fn, c in zip(fns, self.partitions):
274
+ new_callables.append(fn(c))
275
+ self.partitions = new_callables
276
+
277
+ def call(self, args):
278
+ primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args
279
+ args.clear()
280
+ s47 = primals_1
281
+ s87 = primals_2
282
+ s33 = primals_3
283
+ s82 = primals_6
284
+ assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1))
285
+ assert_size_stride(primals_7, (s33, ), (1, ))
286
+ assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1))
287
+ assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1))
288
+ with torch.cuda._DeviceGuard(6):
289
+ torch.cuda.set_device(6)
290
+ buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32)
291
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
292
+ triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33
293
+ triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32
294
+ stream6 = get_raw_stream(6)
295
+ triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream6)
296
+ buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16)
297
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
298
+ stream6 = get_raw_stream(6)
299
+ triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream6)
300
+ del buf0
301
+ buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16)
302
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add]
303
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel = s47*s87
304
+ stream6 = get_raw_stream(6)
305
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream6)
306
+ del primals_4
307
+ del primals_7
308
+ del rsqrt
309
+ del tangents_1
310
+ return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), )
311
+
312
+ runner = Runner(partitions=[])
313
+ call = runner.call
314
+ recursively_apply_fns = runner.recursively_apply_fns
315
+
316
+
317
+ def benchmark_compiled_module(times=10, repeat=10):
318
+ from torch._dynamo.testing import rand_strided
319
+ from torch._inductor.utils import print_performance
320
+ primals_1 = 2
321
+ primals_2 = 2048
322
+ primals_3 = 4096
323
+ primals_6 = 840433664
324
+ primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
325
+ primals_7 = rand_strided((4096, ), (1, ), device='cuda:6', dtype=torch.bfloat16)
326
+ rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.float32)
327
+ tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
328
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1])
329
+ return print_performance(fn, times=times, repeat=repeat)
330
+
331
+
332
+ if __name__ == "__main__":
333
+ from torch._inductor.wrapper_benchmark import compiled_module_main
334
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/hq/chqstdcrwlggtj2cbkjjgtxib54f5qfcipeqs3k27hifudgguv7t.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 8
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks5
245
+ stride_q_idx_h = ks6*ks7
246
+ stride_q_idx_n = ks6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = ks8
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = ks0
596
+ KV_LEN = ks1
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = ks8
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/hv/chvj5h3adlnuxifatrhlirixthstwv5pzbxvuapjby5cz2npck63.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 2048, 'r0_': 16384},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 16384
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x1 = ((xindex // ks0) % ks1)
28
+ x0 = (xindex % ks0)
29
+ x2 = xindex // ks4
30
+ _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
31
+ x5 = xindex
32
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_4 = r0_index // 128
38
+ r0_3 = (r0_index % 128)
39
+ tmp0 = r0_4 + 128*x1
40
+ tmp1 = ks2
41
+ tmp2 = tmp0 < tmp1
42
+ tmp3 = r0_3 + 128*x0
43
+ tmp4 = ks3
44
+ tmp5 = tmp3 < tmp4
45
+ tmp6 = tmp2 & tmp5
46
+ tmp7 = r0_4 + 128*x1
47
+ tmp8 = r0_3 + 128*x0
48
+ tmp9 = tmp7 >= tmp8
49
+ tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0)
50
+ tmp11 = tmp8 < tmp10
51
+ tmp12 = tmp7 < tmp10
52
+ tmp13 = tmp11 & tmp12
53
+ tmp14 = tmp9 & tmp13
54
+ tmp15 = tl.full([1, 1], False, tl.int1)
55
+ tmp16 = tmp15 | tmp14
56
+ tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK])
57
+ tmp18 = tmp8 >= tmp17
58
+ tmp19 = (tmp8 % tmp17)
59
+ tmp20 = tl.full([1, 1], 0, tl.int32)
60
+ tmp21 = tmp19 != tmp20
61
+ tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
62
+ tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0
63
+ tmp24 = tmp22 != tmp23
64
+ tmp25 = tmp21 & tmp24
65
+ tmp26 = tmp19 + tmp17
66
+ tmp27 = tl.where(tmp25, tmp26, tmp19)
67
+ tmp28 = tmp27 < tmp10
68
+ tmp29 = tmp18 & tmp28
69
+ tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
70
+ tmp31 = (tmp30 % tmp17)
71
+ tmp32 = tmp31 != tmp20
72
+ tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0
73
+ tmp34 = tmp33 != tmp23
74
+ tmp35 = tmp32 & tmp34
75
+ tmp36 = tmp31 + tmp17
76
+ tmp37 = tl.where(tmp35, tmp36, tmp31)
77
+ tmp38 = tl.full([1, 1], 0, tl.int64)
78
+ tmp39 = tmp37 == tmp38
79
+ tmp40 = tmp29 & tmp39
80
+ tmp41 = tmp16 | tmp40
81
+ tmp42 = tl.full(tmp41.shape, False, tmp41.dtype)
82
+ tmp43 = tl.where(tmp6, tmp41, tmp42)
83
+ tmp44 = tmp43.to(tl.int64)
84
+ tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK])
85
+ tmp47 = _tmp46 + tmp45
86
+ _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46)
87
+ tmp46 = tl.sum(_tmp46, 1)[:, None]
88
+ tmp48 = tl.full([1, 1], 0, tl.int64)
89
+ tmp49 = tmp46 > tmp48
90
+ tmp50 = tl.full([1, 1], 16384, tl.int64)
91
+ tmp51 = tmp46 < tmp50
92
+ tmp52 = tmp49 & tmp51
93
+ tmp53 = tmp52.to(tl.int8)
94
+ tmp54 = tmp53.to(tl.int32)
95
+ tmp55 = tmp46 == tmp50
96
+ tmp56 = tmp55.to(tl.int8)
97
+ tmp57 = tmp56.to(tl.int32)
98
+ tl.store(out_ptr1 + (x5), tmp54, xmask)
99
+ tl.store(out_ptr2 + (x5), tmp57, xmask)
SpecForge-ext/cache/compiled_kernels/ks/36cfdc5c4318d8e35940f3471fa9a8cde8092c3294a90679819920b4db6ea3bb.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "4UWYNBR3KPWQGNAZ5LIIRE7YAZWTQP4CP3JS6GOSLWYDF5K7WTAA"}
SpecForge-ext/cache/compiled_kernels/ks/cksdatp7sjl5kfr5pxvwrbjelhvz35c35rvym5wgbvhrovwd5isa.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 16384, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 32000
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
29
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_1 = r0_index
36
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
37
+ tmp1 = tmp0.to(tl.float32)
38
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
39
+
40
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
41
+ _tmp3_max, _tmp3_sum, tmp2, False
42
+ )
43
+
44
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
45
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
46
+
47
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
48
+ _tmp3_max, _tmp3_sum, 1, False)
49
+ tmp3 = tmp3[:, None]
50
+ tmp4 = tmp4[:, None]
51
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
52
+ r0_index = r0_offset + r0_base
53
+ r0_mask = r0_index < r0_numel
54
+ roffset = r0_offset
55
+ rindex = r0_index
56
+ r0_1 = r0_index
57
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
58
+ tmp6 = tmp5.to(tl.float32)
59
+ tmp7 = tmp6 - tmp3
60
+ tmp8 = libdevice.exp(tmp7)
61
+ tmp9 = (tmp8 / tmp4)
62
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
SpecForge-ext/cache/compiled_kernels/ks/ckske6cm4vgoewu6hpzmhdk7yxnddtnqlrbts7nwodsrty3grim2.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4096},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 2176
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.full([1], 0, tl.int32)
25
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/kx/ckxgrh6l45wgzd3gv6uy3i3z4hrfyct6es6sh2fdnsi6q4hicyjs.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['10_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/py/cpyon4zgaupgqfwtaeshxummq5taahi4k54ubix2xgrrupxyugiq.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul_2
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:5" = PlaceHolder[target=arg1_1]
47
+ # %argmax : Tensor "i64[2, s14][s14, 1]cuda:5" = PlaceHolder[target=argmax]
48
+ # %arg2_1 : Tensor "b8[151936][1]cuda:5" = PlaceHolder[target=arg2_1]
49
+ # %arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:5" = PlaceHolder[target=arg3_1]
50
+ # %argmax : Tensor "i64[2, s14][s14, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[2, s14][s14, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {})
55
+ # return %argmax,%mul_2
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 4096, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ r0_numel = 151936
75
+ rnumel = r0_numel
76
+ RBLOCK: tl.constexpr = R0_BLOCK
77
+ xoffset = tl.program_id(0) * XBLOCK
78
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
79
+ xmask = xindex < xnumel
80
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
81
+ rbase = r0_base
82
+ x0 = xindex
83
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
84
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
85
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
86
+ r0_index = r0_offset + r0_base
87
+ r0_mask = r0_index < r0_numel
88
+ roffset = r0_offset
89
+ rindex = r0_index
90
+ r0_1 = r0_index
91
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
92
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
93
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
94
+ _tmp2, _tmp2_index, tmp1, rindex
95
+ )
96
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
97
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
98
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
99
+ tmp2 = tmp2_idx[:, None]
100
+ tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
101
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
102
+ tmp4 = tmp2 + tmp3
103
+ tmp5 = tmp2 < 0
104
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
105
+ tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936")
106
+ tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1)
107
+ tmp9 = tmp8.to(tl.int32)
108
+ tmp10 = tmp9.to(tl.int64)
109
+ tmp12 = tmp10 * tmp11
110
+ tl.debug_barrier()
111
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
112
+ ''', device_str='cuda')
113
+
114
+
115
+ async_compile.wait(globals())
116
+ del async_compile
117
+
118
+ class Runner:
119
+ def __init__(self, partitions):
120
+ self.partitions = partitions
121
+
122
+ def recursively_apply_fns(self, fns):
123
+ new_callables = []
124
+ for fn, c in zip(fns, self.partitions):
125
+ new_callables.append(fn(c))
126
+ self.partitions = new_callables
127
+
128
+ def call(self, args):
129
+ arg0_1, arg1_1, arg2_1, arg3_1 = args
130
+ args.clear()
131
+ s24 = arg0_1
132
+ arg1_1_size = arg1_1.size()
133
+ s14 = arg1_1_size[1]
134
+ assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1))
135
+ assert_size_stride(arg2_1, (151936, ), (1, ))
136
+ assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1))
137
+ with torch.cuda._DeviceGuard(5):
138
+ torch.cuda.set_device(5)
139
+ buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64)
140
+ buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse
141
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
142
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14
143
+ stream5 = get_raw_stream(5)
144
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream5)
145
+ del arg1_1
146
+ del arg2_1
147
+ del arg3_1
148
+ return (buf1, )
149
+
150
+ runner = Runner(partitions=[])
151
+ call = runner.call
152
+ recursively_apply_fns = runner.recursively_apply_fns
153
+
154
+
155
+ def benchmark_compiled_module(times=10, repeat=10):
156
+ from torch._dynamo.testing import rand_strided
157
+ from torch._inductor.utils import print_performance
158
+ arg0_1 = 1569
159
+ arg1_1 = rand_strided((2, 1569, 151936), (238387584, 151936, 1), device='cuda:5', dtype=torch.bfloat16)
160
+ arg2_1 = rand_strided((151936, ), (1, ), device='cuda:5', dtype=torch.bool)
161
+ arg3_1 = rand_strided((2, 1569, 1), (1569, 1, 1), device='cuda:5', dtype=torch.int64)
162
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
163
+ return print_performance(fn, times=times, repeat=repeat)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ from torch._inductor.wrapper_benchmark import compiled_module_main
168
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/kx/ckxtdzhg3azhdxeooy2uushwzka4sz2hzjpq5dulk2g2jjweqr6b.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 2
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = ks5
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/l4/858a6c2e50b765fa4386efe0007977eb588741281d2c492d383f481ceaa46b11.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"}
SpecForge-ext/cache/compiled_kernels/l4/cl45ilp34erze7maypgnzjiaafh3lmzk67erw2irtjg7fhwhyggv.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    """Inner loop of the dK/dV backward pass for one KV tile.

    Walks the Q blocks selected by the sparse block mask (`q_indices` /
    `sparse_q_num_blocks`) and accumulates their contributions into the
    running `dk` / `dv` accumulators via `bwd_dkdv_block_mn`.

    NOTE(review): this file looks machine-generated (torch.compile/inductor
    FlexAttention backward template); the constexpr assignments below are the
    compile-time specialization constants baked in for this variant.
    Q_LEN is hard-coded to 2048; KV_LEN is the dynamic size `ks0`.
    """
    # Compile-time specialization constants (inductor-generated; do not edit).
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843  # == 1/sqrt(QK_HEAD_DIM=128)
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # How many BLOCK_M1-sized tiles cover one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504  # 1/ln(2): converts exp() to exp2()
    Q_LEN = 2048
    KV_LEN = ks0

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # qT is laid out transposed (head-dim major) so that tl.dot(k, qT) works.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        # The offset may be a non-contiguous jump when the next sparse Q block
        # is not adjacent to the current one (see get_offset_for_next_block).
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        # Keep the logical row indices in sync with the advanced pointers
        # (used for masking/LSE loads inside bwd_dkdv_block_mn).
        offs_m1 += offset

    return dk, dv
631
+
632
+
633
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    """Process one (M, N) tile of the dK/dV backward pass.

    Recomputes the attention probabilities pT = exp2(qkT*scale*RCP_LN2 - lse)
    for this tile, accumulates dv += pT @ do and dk += dsT @ qT, where
    dsT = pT * (v @ doT - Di) (standard softmax-attention backward).

    NOTE(review): inductor-generated; the tmp41..tmp78 chain is the compiled
    mask_mod of the block mask and must be kept exactly as emitted.
    """
    # Compile-time specialization constants (inductor-generated; do not edit).
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843  # == 1/sqrt(QK_HEAD_DIM=128)
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Rows that were fully masked in the forward pass store -inf LSE; use 0
    # instead so exp2() below yields finite values (they are masked out later).
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    # Identity score_mod: scores pass through unchanged.
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # ---- inductor-compiled mask_mod ----
        # Term A (tmp41..tmp53):  m >= n  AND  n < len[off_z]  AND  m < len[off_z]
        #   (causal attention restricted to a per-batch valid length read from
        #    in_ptr16 — presumably a sequence-length tensor; TODO confirm).
        # Term B (tmp54..tmp77):  n >= 2048  AND  (n mod 2048) < len[off_z]
        #   AND  (n - m) mod 2048 == 0.
        # The signbit/where sequences turn the hardware remainder into a
        # Python-style floored modulo (add the divisor back when the operand
        # signs differ).  Final mask = A | B.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = tl.full([1], 2048, tl.int32)
        tmp55 = tmp43 >= tmp54
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Fold 1/ln(2) into the scores so exp2 computes the softmax exp.
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    # Di = rowsum(dO * O), precomputed in DELTA by the preprocessing kernel.
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # Identity joint_mod: gradient scores pass through unchanged.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is True in this specialization, so this scatter setup is dead
    # code here (kept as emitted by the template).
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
778
+
779
+ # Utility triton funcs
780
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Element-count offset to add to the data pointers to reach the next
    # BLOCK-sized tile, following the sparse block-mask index table.
    # `col_indices` lists, in visit order, the SPARSE_BLOCK-sized blocks to
    # process; each sparse block is covered by SPARSE_BLOCK_MULTIPLE
    # consecutive BLOCK tiles.
    if BLOCKS_ARE_CONTIGUOUS:
        # Dense layout: simply step one tile forward.
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    # Lookahead is masked so we never read past the end of the index table.
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    # True only on the last BLOCK tile of the current sparse block.
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance to the start of the next sparse block, minus the tiles already
    # traversed inside the current one.
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Branchless select between "jump to next sparse block" and "step BLOCK".
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
795
+
796
@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wrap indices into [0, max_len) so out-of-range lanes still address valid
    # memory; the affected scores are masked out separately by the caller.
    # With max_len=None (the divisible case) indices are trusted as-is.
    return indices % max_len if max_len is not None else indices
799
+
800
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load a 2D tile through a Triton block pointer, boundary-checking only
    # the axes that may run out of bounds (axis 0 guarded by IS_DIVISIBLE,
    # axis 1 by SAFE_HEAD_DIM); out-of-bounds elements are zero-filled.
    # The constexpr flags make each branch a compile-time specialization.
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Masked 2D tile load from raw pointers.  When strides are supplied the
    # (offs_m x offs_n) pointer grid is built here; otherwise `ptr` is assumed
    # to already be a 2D pointer grid.  Axes whose extent is not guaranteed
    # divisible by the tile size are masked, with out-of-bounds lanes
    # zero-filled (other=0.0).  The constexpr flags select one branch at
    # compile time.
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else: # Both divisible
        return tl.load(ptr)