Lekr0 committed
Commit 842e4fb · verified · 1 Parent(s): 62dca4c

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. progress/SpecForge/cache/compiled_kernels/22/2cc448de6eeb5f6db19c2adf4fa08d257ac050432532b15d4b1b5f447657bb74.best_config +1 -0
  2. progress/SpecForge/cache/compiled_kernels/22/c225p5q54jhc2rfoccuzlgejscvq2in5jzxlzcilu44cplhbfreo.py +28 -0
  3. progress/SpecForge/cache/compiled_kernels/22/c22yrjhqxirm4hkhfziohi6cfktos4modosbdflw25tw7a5d5gy7.py +799 -0
  4. progress/SpecForge/cache/compiled_kernels/2h/c2hwir33itr7umd7f5wx6cpaiwom2wrrbqi5cyznolnljriqz7pk.py +1028 -0
  5. progress/SpecForge/cache/compiled_kernels/2m/50bb8a1cf8ca03a72155f9f84fae7c10cfe881f9f7d97e1e3f94ec85776c3639.best_config +1 -0
  6. progress/SpecForge/cache/compiled_kernels/2m/c2mzrxons6jvrj3mv77db5xyv5m4z73mego5v77sl66o6wuh5dbk.py +28 -0
  7. progress/SpecForge/cache/compiled_kernels/2p/c2pjjgigeh2ro4r74dzlvlf7os5rhnmyche5rzwawor6zxb6rvk2.py +1018 -0
  8. progress/SpecForge/cache/compiled_kernels/2x/c2xecscuz5jhvznv7jn4k545b7kcexuko5lz3em6woeo7u2ftonz.py +58 -0
  9. progress/SpecForge/cache/compiled_kernels/2x/c2ximikyisa7xxnki36flzcsdr4ziwruq7ujf3zymsuxon5pqv57.py +707 -0
  10. progress/SpecForge/cache/compiled_kernels/2x/dde0479cca0d878e6e0800ec13f7c80962354e837542bfc5f11f7b49306d323e.best_config +1 -0
  11. progress/SpecForge/cache/compiled_kernels/3h/6ee97c795357f97e7127237e15db9bd5fb14510b837eeb5094115cfaa1802d32.best_config +1 -0
  12. progress/SpecForge/cache/compiled_kernels/3h/c3h3fb5vqykgr7s3powfrnsc5alooplbijdgjizqo3xq5psavrvz.py +28 -0
  13. progress/SpecForge/cache/compiled_kernels/3h/c3hro2ygwh2ixqhmbrrdsjq6biaehv6lm5cbeo6yhlo6ssqkwpha.py +799 -0
  14. progress/SpecForge/cache/compiled_kernels/3p/c3pdrhexk4rwol7f5l5vh7n543dj6piq6gw5k66g2p4vlyhopnop.py +58 -0
  15. progress/SpecForge/cache/compiled_kernels/3p/d4e91f4bc49d9cfc59a03caa3a2e04988f99c358762e8d23eed306dbbe3eae25.best_config +1 -0
  16. progress/SpecForge/cache/compiled_kernels/3s/c3spq2k2yeawxvgwl4dczrad6qwkidiiyxz5xwsucqivwlx625g7.py +534 -0
  17. progress/SpecForge/cache/compiled_kernels/3u/c3umapah7vcozhvfk5uovlssor7v533y4crphqgd677nuoizbpvj.py +799 -0
  18. progress/SpecForge/cache/compiled_kernels/42/c424arzgjg22xrcyl4orsbfthh3vxddttchjdd7yswdd5pdxdhtv.py +1019 -0
  19. progress/SpecForge/cache/compiled_kernels/44/3728d77fd47f8b1056ec8670d5b1bd262db03ae9994292fee6203d32e3d9cd03.best_config +1 -0
  20. progress/SpecForge/cache/compiled_kernels/44/c44m5klhlzg7nfvzfelnbb3hjh2jwzh2e5yyk3vtcvhyw6rbnjo6.py +54 -0
  21. progress/SpecForge/cache/compiled_kernels/4d/c4d7fh2egdfps7aogbncwlp3ihfwtff243bbobq7vrxj2m2grl64.py +51 -0
  22. progress/SpecForge/cache/compiled_kernels/4d/fd68b3c1a3fd19883dc58697393b6044e6217afda9ea11f84bd620545197dd6b.best_config +1 -0
  23. progress/SpecForge/cache/compiled_kernels/4h/c4h32peoig2erjdxibxrq3sbpm533ci3z57ntqjhdemzxp2rhysl.py +799 -0
  24. progress/SpecForge/cache/compiled_kernels/4k/c4korm4huj2wookuw6gikboxrsp3m5yt45c7fxucyujswm5fgb3u.py +534 -0
  25. progress/SpecForge/cache/compiled_kernels/4x/31f9d1ee4882fe2005f02592ea2d9f20a1835b42c5baefd7795e8640f97fdc16.best_config +1 -0
  26. progress/SpecForge/cache/compiled_kernels/4x/c4xjhgyzut6anhrjeinspoinohfxvyl6skr4gd3vfrscrvsevmya.py +28 -0
  27. progress/SpecForge/cache/compiled_kernels/5b/c5blvz5sxoj2veuexokuub2zm2pg4l2nqbbny4rr2jhsiiyw6njy.py +534 -0
  28. progress/SpecForge/cache/compiled_kernels/5g/c5g7nnbi3zupsx7kdee2ed6g2fgrtd2jxyggsjpckfg5p7rps4qm.py +534 -0
  29. progress/SpecForge/cache/compiled_kernels/5j/3cc65a0fdb544c73efb7240355b77da3f1ab394b46f272fa923c368e6cc63c34.best_config +1 -0
  30. progress/SpecForge/cache/compiled_kernels/5j/c5j7yk5hlaaxs42qwjlmoczwtoukaw2dio2o6p7qfekdy5upikyv.py +54 -0
  31. progress/SpecForge/cache/compiled_kernels/5w/c5wutjfcact264ykgcamj2asvz4eqe3ygz47upjgib2qw5rnnihu.py +1019 -0
  32. progress/SpecForge/cache/compiled_kernels/5z/c5zh2j5k5rlsr5zd4tfbvoplpwmbtizbrldjq2hw4nndmjztlcuy.py +879 -0
  33. progress/SpecForge/cache/compiled_kernels/66/c66vzxeqjq4tywx6ezsscs3u3rb6yxac26rzmkwrlzc3kmkcnhlf.py +799 -0
  34. progress/SpecForge/cache/compiled_kernels/6b/c6b364ingwlftc5camjox4wdd5z5l4famigf6ojv4cyji4ju37fy.py +879 -0
  35. progress/SpecForge/cache/compiled_kernels/6g/826e9651ad1f65a7d666097ac7518bb4b4d3dee1984132523b860dd02b66fff1.best_config +1 -0
  36. progress/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py +58 -0
  37. progress/SpecForge/cache/compiled_kernels/6o/c6ovzyfo6vkdwwzou6dtdvw7qjf65ifmzpcoltl2nx2xuluryjcy.py +48 -0
  38. progress/SpecForge/cache/compiled_kernels/6o/df004f0eefe2693a59f2bae06581f78ab07b5ce2ec28936911ef13f1152e2ec9.best_config +1 -0
  39. progress/SpecForge/cache/compiled_kernels/6u/c6ulsdn73forgosxqs5bes2cerczsehypg7jodd4snit3gcqp6el.py +27 -0
  40. progress/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py +534 -0
  41. progress/SpecForge/cache/compiled_kernels/6u/e5329724392dcdd68d88f082f57e8929de539c9aa187c3a314edefdc595437d5.best_config +1 -0
  42. progress/SpecForge/cache/compiled_kernels/7a/aee791ee3934869dfa55caee4270f116dc979737c2bbcce40af5d5394ccc9ac8.best_config +1 -0
  43. progress/SpecForge/cache/compiled_kernels/7a/c7a2brsshxp6zz4foe62t5ivwbd2dwr6ytjbhxp22vq2evdotx5z.py +28 -0
  44. progress/SpecForge/cache/compiled_kernels/7a/c7adkdqab5cvqxxnwnn5au23gorh2eg33cfxneh7bzb7untnuvpw.py +534 -0
  45. progress/SpecForge/cache/compiled_kernels/7i/c7ijsdt7wst5xe64qslxvdevuoxlscozrh6zigwqavztuz3rptdj.py +58 -0
  46. progress/SpecForge/cache/compiled_kernels/7i/e1454135615b9b6420e5ef4fe0804f1b1346f398b5afb34dbe8085f0e900c8aa.best_config +1 -0
  47. progress/SpecForge/cache/compiled_kernels/7n/c7n4jk5r4lsbq62vtrxzouvawlecnbfhy3owedw4ewuid7d56bjs.py +582 -0
  48. progress/SpecForge/cache/compiled_kernels/a3/ca3omjumwqpxxjrgphxuxva3yanssfkbnvrp3buqomyudb2eg4nc.py +534 -0
  49. progress/SpecForge/cache/compiled_kernels/ad/4a5ce6c582fc1ef37d4cf3003d603da533264d4c59e6e0cb171d0b7490f32260.best_config +1 -0
  50. progress/SpecForge/cache/compiled_kernels/ad/b7e210dbfa93430d766b2fdeddfba773a52faf0eca21a68632a8853e9c3ecaf4.best_config +1 -0
progress/SpecForge/cache/compiled_kernels/22/2cc448de6eeb5f6db19c2adf4fa08d257ac050432532b15d4b1b5f447657bb74.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"}
progress/SpecForge/cache/compiled_kernels/22/c225p5q54jhc2rfoccuzlgejscvq2in5jzxlzcilu44cplhbfreo.py ADDED
@@ -0,0 +1,28 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+     size_hints={'x': 32768},
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = xindex < xnumel
+     x2 = xindex
+     x0 = (xindex % ks0)
+     x1 = triton_helpers.div_floor_integer(xindex, ks0)
+     tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
+     tmp1 = 0.6931471805599453
+     tmp2 = tmp0 * tmp1
+     tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
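The kernel above is Inductor's generated form of a pointwise multiply by ln 2 (the constant 0.6931471805599453), using the usual Triton pattern: derive per-program offsets from tl.program_id and tl.arange, mask the tail, load, compute, store. For reference, a minimal hand-written kernel with the same structure might look roughly like the sketch below; it is not the cached kernel, and the name scale_by_ln2 is illustrative. The generated version additionally carries size hints and autotuning metadata and writes through a dynamic-shape stride expression (the ks0 terms).

import torch
import triton
import triton.language as tl

@triton.jit
def scale_by_ln2(in_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles BLOCK contiguous elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements  # guard the final partial block
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 0.6931471805599453, mask=mask)

x = torch.randn(10_000, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
scale_by_ln2[grid](x, out, x.numel(), BLOCK=1024)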
progress/SpecForge/cache/compiled_kernels/22/c22yrjhqxirm4hkhfziohi6cfktos4modosbdflw25tw7a5d5gy7.py ADDED
@@ -0,0 +1,799 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 2555904, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 638976, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 638976, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 2555904, 79872, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 2555904, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 638976, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 624
106
+ ZKV = 1
107
+ KV_LEN = 624
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 5
148
+ stride_kv_idx_h = 25
149
+ stride_kv_idx_m = 5
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 5
245
+ stride_q_idx_h = 25
246
+ stride_q_idx_n = 5
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 79872*off_hkv + 638976*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 624
385
+ KV_LEN = 624
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order to since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = 624
578
+ KV_LEN = 624
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
664
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
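The template above is the flex-attention backward pass: one group of programs accumulates dQ, the other accumulates dK and dV while looping over the GQA query heads, with the block-sparse index buffers (KV_IDX, Q_IDX and their FULL_* variants) deciding which tiles are visited and whether the mask must be re-applied. The inner blocks implement the standard attention backward identities; a sketch in the notation of the kernel's own comments, where DELTA = rowsum(OUT * DO), S = Q K^T * SM_SCALE, and P = exp(S - LSE) (the kernel evaluates the exponentials in base 2 via RCP_LN2 and exp2):

\begin{aligned}
  dP &= dO\, V^{\top}, \\
  dS &= P \odot (dP - \Delta), \\
  dQ &= dS\, K \cdot \mathrm{SM\_SCALE}, \\
  dK &= dS^{\top} Q \cdot \mathrm{SM\_SCALE}, \\
  dV &= P^{\top} dO .
\end{aligned}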
progress/SpecForge/cache/compiled_kernels/2h/c2hwir33itr7umd7f5wx6cpaiwom2wrrbqi5cyznolnljriqz7pk.py ADDED
@@ -0,0 +1,1028 @@
1
+ # AOT ID: ['1_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/il/cilouthst7wssvb523k44wgbbccu5i25bgzvyf7k74iapy3kurts.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0]
44
+ # %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=tangents_2]
45
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
46
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
47
+ # return %buf0,%buf1
48
+ triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', '''
49
+ import triton
50
+ import triton.language as tl
51
+
52
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
53
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
54
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
55
+ triton_helpers.set_driver_to_gpu()
56
+
57
+ @triton_heuristics.reduction(
58
+ size_hints={'x': 65536, 'r0_': 128},
59
+ reduction_hint=ReductionHint.DEFAULT,
60
+ filename=__file__,
61
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
62
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
63
+ )
64
+ @triton.jit
65
+ def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
66
+ r0_numel = 128
67
+ rnumel = r0_numel
68
+ RBLOCK: tl.constexpr = R0_BLOCK
69
+ xoffset = tl.program_id(0) * XBLOCK
70
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
71
+ xmask = xindex < xnumel
72
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
73
+ rbase = r0_base
74
+ x0 = (xindex % ks0)
75
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
76
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
77
+ x3 = xindex
78
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
79
+ r0_index = r0_offset + r0_base
80
+ r0_mask = r0_index < r0_numel
81
+ roffset = r0_offset
82
+ rindex = r0_index
83
+ r0_2 = r0_index
84
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
85
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp2 = tmp0 * tmp1
87
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
88
+ tmp5 = _tmp4 + tmp3
89
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
90
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
91
+ tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp8 = 0.6931471805599453
94
+ tmp9 = tmp7 * tmp8
95
+ tmp10 = 1.4426950408889634
96
+ tmp11 = tmp9 * tmp10
97
+ tmp12 = tmp6 - tmp11
98
+ tl.store(out_ptr1 + (x3), tmp12, xmask)
99
+ ''', device_str='cuda')
100
+
101
+
102
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/2d/c2dfd76xaqhs43mudrn3koh54vgu4ljjh5tphjwwwqhxcdloa7kl.py
103
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
104
+ # Source node to ATen node mapping:
105
+ # Graph fragment:
106
+ # %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2]
107
+ # %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4]
108
+ # %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6]
109
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1]
110
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1]
111
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1]
112
+ # %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3]
113
+ # %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5]
114
+ # %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:6" = PlaceHolder[target=primals_13]
115
+ # %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:6" = PlaceHolder[target=primals_9]
116
+ # %primals_16 : Tensor "i32[1, 1, 13][13, 13, 1]cuda:6" = PlaceHolder[target=primals_16]
117
+ # %primals_17 : Tensor "i32[1, 1, 13, 13][169, 169, 13, 1]cuda:6" = PlaceHolder[target=primals_17]
118
+ # %primals_14 : Tensor "i32[1, 1, 13][13, 13, 1]cuda:6" = PlaceHolder[target=primals_14]
119
+ # %primals_15 : Tensor "i32[1, 1, 13, 13][169, 169, 13, 1]cuda:6" = PlaceHolder[target=primals_15]
120
+ # %primals_18 : Tensor "i32[1, 1, 13][13, 13, 1]cuda:6" = PlaceHolder[target=primals_18]
121
+ # %primals_19 : Tensor "i32[1, 1, 13, 13][169, 169, 13, 1]cuda:6" = PlaceHolder[target=primals_19]
122
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
123
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
124
+ # return %getitem_4
125
+ triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
126
+ import triton
127
+ import triton.language as tl
128
+
129
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
130
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
131
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
132
+
133
+ @triton_heuristics.template(
134
+
135
+ num_stages=3,
136
+ num_warps=8,
137
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
138
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
139
+
140
+ )
141
+ @triton.jit
142
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
143
+ PRESCALE_QK : tl.constexpr = False
144
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
145
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
146
+ WRITE_DQ : tl.constexpr = True
147
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
148
+ OUTPUT_MAX : tl.constexpr = False
149
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
150
+ IS_DIVISIBLE : tl.constexpr = False
151
+ SM_SCALE : tl.constexpr = 0.08838834764831845
152
+ GQA_SHARED_HEADS : tl.constexpr = 4
153
+ HAS_FULL_BLOCKS : tl.constexpr = True
154
+ QK_HEAD_DIM : tl.constexpr = 128
155
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ V_HEAD_DIM : tl.constexpr = 128
157
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
158
+ SAFE_HEAD_DIM : tl.constexpr = True
159
+ BLOCK_M1 : tl.constexpr = 64
160
+ BLOCK_N1 : tl.constexpr = 128
161
+ BLOCK_M2 : tl.constexpr = 128
162
+ BLOCK_N2 : tl.constexpr = 64
163
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
164
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
165
+ INDEX_DTYPE : tl.constexpr = tl.int32
166
+ Q = arg_Q
167
+ K = arg_K
168
+ V = arg_V
169
+ LSE = arg_LSE
170
+ DELTA = arg_DELTA
171
+ DO = arg_DO
172
+ DQ = arg_DQ
173
+ DV = arg_DV
174
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
175
+ KV_IDX = arg_KV_IDX
176
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
177
+ Q_IDX = arg_Q_IDX
178
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
179
+ FULL_KV_IDX = arg_FULL_KV_IDX
180
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
181
+ FULL_Q_IDX = arg_FULL_Q_IDX
182
+
183
+ # Sub notation for this kernel:
184
+ #
185
+ # Q: Query, K: Key, V: Value
186
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
187
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
188
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
189
+ # DK: Derivative of Key, is written to via the store_output call due to some limitations with
190
+ # inductor codegen
191
+ # M: Number of queries, N: Number of keys/values
192
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
193
+ # V_HEAD_DIM: The dimension of the value embeddings
194
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
195
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
196
+ # (Modifiable) Performance tuning options
197
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
198
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
199
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
200
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
201
+ #
202
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
203
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
204
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
205
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
206
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
207
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
208
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
209
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
210
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
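+ # Editor's note (illustrative, not part of the generated template): with the GQA constants
+ # baked into this kernel (HQ = 32, HKV = 8, GQA_SHARED_HEADS = 4), query heads are mapped
+ # onto kv heads in contiguous groups of four via off_hq = off_hkv * GQA_SHARED_HEADS + off_g,
+ # so query heads 0-3 read kv head 0, heads 4-7 read kv head 1, and so on up to heads 28-31
+ # reading kv head 7.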
211
+
212
+ # The below are kernel options that can be applied for certain score_mods,
213
+ # or involve a numerics vs. perf tradeoff
214
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
215
+ # about 20% more numerical error, but slightly faster.
216
+
217
+ # Define strides of inputs
218
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
219
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
220
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
221
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
222
+
223
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
224
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
225
+
226
+ ZQ = 1
227
+ HQ = 32
228
+ HKV = 8
229
+ Q_LEN = ks0
230
+ ZKV = 1
231
+ KV_LEN = ks1
232
+
233
+ MATMUL_PRECISION = Q.dtype.element_ty
234
+
235
+ pid = tl.program_id(0).to(INDEX_DTYPE)
236
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
237
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
238
+
239
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
240
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
241
+ off_zkv = off_zq % ZKV # kv batch idx
242
+
243
+ SPARSE_Z = 1
244
+ SPARSE_HQ = 1
245
+
246
+ sparse_idx_z = off_zq % SPARSE_Z
247
+
248
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
249
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
250
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
251
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
252
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
253
+
254
+ # offset K, V, DV pointers for batch/kv-head
255
+ K += k_adj
256
+ V += v_adj
257
+ DV += dv_adj
258
+
259
+ RCP_LN2 = 1.44269504
260
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
261
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
262
+
263
+ if pid >= NUM_KV_BLOCKS:
264
+ off_pid = pid - NUM_KV_BLOCKS
265
+ # THIS BLOCK DOES DQ
266
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
267
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
268
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
269
+ start_m2_block = off_pid % NUM_Q_BLOCKS
270
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
271
+ stride_kv_num_blks_h = ks2
272
+ stride_kv_idx_h = ks3*ks4
273
+ stride_kv_idx_m = ks4
274
+
275
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
276
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
277
+
278
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
279
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
280
+
281
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
282
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
283
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
284
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
285
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
286
+
287
+ Q2 = Q + q_adj2
288
+ DO2 = DO + do_adj2
289
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
290
+ # if Q is broadcasted)
291
+ DQ2 = DQ + dq_adj2
292
+ LSE2 = LSE + off_chz2
293
+ DELTA2 = DELTA + off_chz2
294
+
295
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
296
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
297
+
298
+ start_m2 = start_m2_block * BLOCK_M2
299
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
300
+
301
+ # load Q and do: they stay in SRAM throughout the inner loop.
302
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
303
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
304
+
305
+ if PRESCALE_QK:
306
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
307
+
308
+ if IS_DIVISIBLE:
309
+ Di = tl.load(DELTA2 + offs_m2)
310
+ lse = tl.load(LSE2 + offs_m2)
311
+ else:
312
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
313
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
314
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
315
+ lse = lse[:, None]
316
+
317
+ # ~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
318
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
319
+ kv_indices = KV_IDX + sparse_kv_idx_offset
320
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
321
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
322
+
323
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
324
+ dq = bwd_dq_inner(
325
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
326
+ K, V,
327
+ dq, q, do, Di, lse,
328
+ off_zq, off_hq2, offs_m2, offs_n2,
329
+ stride_kn, stride_kd, stride_vn, stride_vd,
330
+ kv_indices, sparse_kv_num_blocks,
331
+ MATMUL_PRECISION,
332
+ IS_FULL_BLOCKS=False,
333
+ )
334
+
335
+ if HAS_FULL_BLOCKS:
336
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
337
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
338
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
339
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
340
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
341
+
342
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
343
+ dq = bwd_dq_inner(
344
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
345
+ K, V,
346
+ dq, q, do, Di, lse,
347
+ off_zq, off_hq2, offs_m2, offs_n2,
348
+ stride_kn, stride_kd, stride_vn, stride_vd,
349
+ kv_indices, sparse_kv_num_blocks,
350
+ MATMUL_PRECISION,
351
+ IS_FULL_BLOCKS=True,
352
+ )
353
+
354
+ # Write back dQ.
355
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
356
+ dq *= SM_SCALE
357
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
358
+ tl.store(dq_ptrs, dq)
359
+ else:
360
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
361
+ else:
362
+ # THIS BLOCK DOES DK & DV
363
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
364
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
365
+
366
+ pid_mask = pid // SPARSE_KV_MULTIPLE
367
+
368
+ stride_q_num_blks_h = 13
369
+ stride_q_idx_h = 169
370
+ stride_q_idx_n = 13
371
+
372
+
373
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
374
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
375
+
376
+ start_n1 = pid * BLOCK_N1
377
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
378
+
379
+ # load K and V: they stay in SRAM throughout the inner loop.
380
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
381
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
382
+
383
+ if PRESCALE_QK:
384
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
385
+
386
+ for off_g in range(0, GQA_SHARED_HEADS):
387
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
388
+
389
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
390
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
391
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
392
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
393
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
394
+
395
+ Q1 = Q + q_adj1
396
+ DO1 = DO + do_adj1
397
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
398
+ # if Q is broadcasted)
399
+ LSE1 = LSE + off_chz1
400
+ DELTA1 = DELTA + off_chz1
401
+
402
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
403
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
404
+
405
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
406
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
407
+
408
+ # ~~~~~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
409
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
410
+ q_indices = Q_IDX + sparse_q_idx_offset
411
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
412
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
413
+
414
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
415
+ dk, dv = bwd_dkdv_inner(
416
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
417
+ Q1, DO1, DELTA1, LSE1,
418
+ dk, dv, k, v,
419
+ off_zq, off_hq1, offs_n1, offs_m1,
420
+ stride_qm, stride_qd, stride_dom, stride_dod,
421
+ q_indices, sparse_q_num_blocks,
422
+ MATMUL_PRECISION,
423
+ IS_FULL_BLOCKS=False,
424
+ )
425
+
426
+
427
+ if HAS_FULL_BLOCKS:
428
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
429
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
430
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
431
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
432
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
433
+
434
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
435
+ dk, dv = bwd_dkdv_inner(
436
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
437
+ Q1, DO1, DELTA1, LSE1,
438
+ dk, dv, k, v,
439
+ off_zq, off_hq1, offs_n1, offs_m1,
440
+ stride_qm, stride_qd, stride_dom, stride_dod,
441
+ q_indices, sparse_q_num_blocks,
442
+ MATMUL_PRECISION,
443
+ IS_FULL_BLOCKS=True,
444
+ )
445
+
446
+ # Write back dV and dK.
447
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
448
+
449
+ index_n = offs_n1[:, None]
450
+ index_k = offs_k[None, :]
451
+ index_v = offs_v[None, :]
452
+
453
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
454
+ tl.store(dv_ptrs, dv)
455
+ else:
456
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
457
+
458
+ dk *= SM_SCALE
459
+
460
+ if SAFE_HEAD_DIM:
461
+ mask = index_n < KV_LEN
462
+ else:
463
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
464
+
465
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
466
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
467
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
468
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
469
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
470
+
471
+ @triton.jit
472
+ def bwd_dq_inner(
473
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
474
+ K, V, # pointers
475
+ dq, q, do, Di, lse,
476
+ off_z, off_hq, offs_m2, offs_n2,
477
+ stride_kn, stride_kd, stride_vn, stride_vd,
478
+ kv_indices, sparse_kv_num_blocks,
479
+ MATMUL_PRECISION,
480
+ IS_FULL_BLOCKS,
481
+ ):
482
+ PRESCALE_QK : tl.constexpr = False
483
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
484
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
485
+ WRITE_DQ : tl.constexpr = True
486
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
487
+ OUTPUT_MAX : tl.constexpr = False
488
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
489
+ IS_DIVISIBLE : tl.constexpr = False
490
+ SM_SCALE : tl.constexpr = 0.08838834764831845
491
+ GQA_SHARED_HEADS : tl.constexpr = 4
492
+ HAS_FULL_BLOCKS : tl.constexpr = True
493
+ QK_HEAD_DIM : tl.constexpr = 128
494
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ V_HEAD_DIM : tl.constexpr = 128
496
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
497
+ SAFE_HEAD_DIM : tl.constexpr = True
498
+ BLOCK_M1 : tl.constexpr = 64
499
+ BLOCK_N1 : tl.constexpr = 128
500
+ BLOCK_M2 : tl.constexpr = 128
501
+ BLOCK_N2 : tl.constexpr = 64
502
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
503
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
504
+ INDEX_DTYPE : tl.constexpr = tl.int32
505
+
506
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
507
+ RCP_LN2: tl.constexpr = 1.44269504
508
+ Q_LEN = ks0
509
+ KV_LEN = ks1
510
+
511
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
512
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
513
+
514
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
515
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
516
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
517
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
518
+
519
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
520
+
521
+ for start_n in range(0, hi):
522
+ dq = bwd_dq_block_mn(
523
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
524
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
525
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
526
+ stride_kn, stride_kd, stride_vn, stride_vd,
527
+ kv_indices, sparse_kv_num_blocks,
528
+ MATMUL_PRECISION, RCP_LN2,
529
+ IS_FULL_BLOCKS,
530
+ )
531
+
532
+ # Increment pointers.
533
+ offset = get_offset_for_next_block(
534
+ start_n, kv_indices, sparse_kv_num_blocks,
535
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
536
+ )
537
+
538
+ kT_ptrs += offset * stride_kn
539
+ vT_ptrs += offset * stride_vn
540
+
541
+ offs_n2 += offset
542
+
543
+ return dq
544
+
545
+
546
+ @triton.jit
547
+ def bwd_dq_block_mn(
548
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
549
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
550
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
551
+ stride_kn, stride_kd, stride_vn, stride_vd,
552
+ kv_indices, sparse_kv_num_blocks,
553
+ MATMUL_PRECISION, RCP_LN2,
554
+ IS_FULL_BLOCKS,
555
+ ):
556
+ PRESCALE_QK : tl.constexpr = False
557
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
558
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
559
+ WRITE_DQ : tl.constexpr = True
560
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
561
+ OUTPUT_MAX : tl.constexpr = False
562
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
563
+ IS_DIVISIBLE : tl.constexpr = False
564
+ SM_SCALE : tl.constexpr = 0.08838834764831845
565
+ GQA_SHARED_HEADS : tl.constexpr = 4
566
+ HAS_FULL_BLOCKS : tl.constexpr = True
567
+ QK_HEAD_DIM : tl.constexpr = 128
568
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ V_HEAD_DIM : tl.constexpr = 128
570
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
571
+ SAFE_HEAD_DIM : tl.constexpr = True
572
+ BLOCK_M1 : tl.constexpr = 64
573
+ BLOCK_N1 : tl.constexpr = 128
574
+ BLOCK_M2 : tl.constexpr = 128
575
+ BLOCK_N2 : tl.constexpr = 64
576
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
577
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
578
+ INDEX_DTYPE : tl.constexpr = tl.int32
579
+
580
+
581
+ # NB reversed order since K is transposed
582
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
583
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
584
+ if not PRESCALE_QK:
585
+ qk *= SM_SCALE
586
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
587
+ pre_mod_scores = qk
588
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
589
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
590
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
591
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
592
+
593
+ tmp0 = (qk)
594
+ post_mod_scores = tmp0
595
+
596
+
597
+
598
+
599
+ if not IS_DIVISIBLE:
600
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
601
+
602
+ if not IS_FULL_BLOCKS:
603
+ tmp1 = (m)
604
+ tmp2 = tl.full([1], 0, tl.int32)
605
+ tmp3 = tmp1 < tmp2
606
+ tmp4 = (n)
607
+ tmp5 = tmp4 <= tmp1
608
+ tmp6 = tmp3 & tmp5
609
+ tmp7 = tmp1 >= tmp2
610
+ tmp8 = tmp4 < tmp2
611
+ tmp9 = tmp7 & tmp8
612
+ tmp10 = tmp8 == 0
613
+ tmp11 = tmp7 & tmp10
614
+ tmp12 = tmp1 - tmp2
615
+ tmp13 = tl.full([1], 16, tl.int32)
616
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
617
+ tmp15 = tmp4 - tmp2
618
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
619
+ tmp17 = tmp14 == tmp16
620
+ tmp18 = tmp11 & tmp17
621
+ tmp19 = tmp9 | tmp18
622
+ tmp20 = tmp6 | tmp19
623
+ mask_mod_output = tmp20
624
+
625
+
626
+ # apply mask for partial masked block
627
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
628
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
629
+ if not PRESCALE_QK:
630
+ post_mod_scores *= RCP_LN2
631
+ p = tl.math.exp2(post_mod_scores - lse)
632
+ # Compute dP and dS.
633
+ # NB reversed order since V is transposed
634
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
635
+
636
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
637
+ ds = p * (dp - Di[:, None])
638
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
639
+ tmp21 = (ds)
640
+ grad_scores = tmp21
641
+
642
+
643
+ if not IS_DIVISIBLE:
644
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
645
+
646
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
647
+ if WRITE_DQ:
648
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
649
+
650
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
651
+ ds = grad_scores
652
+
653
+ if not IS_FULL_BLOCKS:
654
+ # (grads) apply mask for partially unmasked block
655
+ ds = tl.where(mask_mod_output, ds, 0.0)
656
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
657
+ ds = ds.to(MATMUL_PRECISION)
658
+ # Compute dQ.
659
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
660
+
661
+ return dq
662
+
663
+
664
+ @triton.jit
665
+ def bwd_dkdv_inner(
666
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
667
+ Q, DO, DELTA, LSE, # pointers
668
+ dk, dv, k, v,
669
+ off_z, off_hq, offs_n1, offs_m1,
670
+ stride_qm, stride_qd, stride_dom, stride_dod,
671
+ q_indices, sparse_q_num_blocks,
672
+ MATMUL_PRECISION,
673
+ IS_FULL_BLOCKS,
674
+ ):
675
+ PRESCALE_QK : tl.constexpr = False
676
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
677
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
678
+ WRITE_DQ : tl.constexpr = True
679
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
680
+ OUTPUT_MAX : tl.constexpr = False
681
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
682
+ IS_DIVISIBLE : tl.constexpr = False
683
+ SM_SCALE : tl.constexpr = 0.08838834764831845
684
+ GQA_SHARED_HEADS : tl.constexpr = 4
685
+ HAS_FULL_BLOCKS : tl.constexpr = True
686
+ QK_HEAD_DIM : tl.constexpr = 128
687
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
688
+ V_HEAD_DIM : tl.constexpr = 128
689
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
690
+ SAFE_HEAD_DIM : tl.constexpr = True
691
+ BLOCK_M1 : tl.constexpr = 64
692
+ BLOCK_N1 : tl.constexpr = 128
693
+ BLOCK_M2 : tl.constexpr = 128
694
+ BLOCK_N2 : tl.constexpr = 64
695
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
696
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
697
+ INDEX_DTYPE : tl.constexpr = tl.int32
698
+
699
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
700
+ RCP_LN2: tl.constexpr = 1.44269504
701
+ Q_LEN = ks0
702
+ KV_LEN = ks1
703
+
704
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
705
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
706
+
707
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
708
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
709
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
710
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
711
+
712
+ # The minimum is needed to handle the case where we run with a super large
713
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
714
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
715
+
716
+ for start_m in range(0, hi):
717
+ dk, dv = bwd_dkdv_block_mn(
718
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
719
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
720
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
721
+ stride_qm, stride_qd, stride_dom, stride_dod,
722
+ q_indices, sparse_q_num_blocks,
723
+ MATMUL_PRECISION, RCP_LN2,
724
+ IS_FULL_BLOCKS,
725
+ )
726
+ # Increment pointers.
727
+ offset = get_offset_for_next_block(
728
+ start_m, q_indices, sparse_q_num_blocks,
729
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
730
+ )
731
+
732
+ qT_ptrs += offset * stride_qm
733
+ do_ptrs += offset * stride_dom
734
+ offs_m1 += offset
735
+
736
+ return dk, dv
737
+
738
+
739
+ @triton.jit
740
+ def bwd_dkdv_block_mn(
741
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
742
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
743
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
744
+ stride_qm, stride_qd, stride_dom, stride_dod,
745
+ q_indices, sparse_q_num_blocks,
746
+ MATMUL_PRECISION, RCP_LN2,
747
+ IS_FULL_BLOCKS,
748
+ ):
749
+ PRESCALE_QK : tl.constexpr = False
750
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
751
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
752
+ WRITE_DQ : tl.constexpr = True
753
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
754
+ OUTPUT_MAX : tl.constexpr = False
755
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
756
+ IS_DIVISIBLE : tl.constexpr = False
757
+ SM_SCALE : tl.constexpr = 0.08838834764831845
758
+ GQA_SHARED_HEADS : tl.constexpr = 4
759
+ HAS_FULL_BLOCKS : tl.constexpr = True
760
+ QK_HEAD_DIM : tl.constexpr = 128
761
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
762
+ V_HEAD_DIM : tl.constexpr = 128
763
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
764
+ SAFE_HEAD_DIM : tl.constexpr = True
765
+ BLOCK_M1 : tl.constexpr = 64
766
+ BLOCK_N1 : tl.constexpr = 128
767
+ BLOCK_M2 : tl.constexpr = 128
768
+ BLOCK_N2 : tl.constexpr = 64
769
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
770
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
771
+ INDEX_DTYPE : tl.constexpr = tl.int32
772
+
773
+
774
+ # NB reversed order since Q is transposed
775
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
776
+ # Load LSE before computing qk to reduce pipeline stall.
777
+ if IS_DIVISIBLE:
778
+ lse = tl.load(LSE + offs_m1)
779
+ else:
780
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
781
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
782
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
783
+ if not PRESCALE_QK:
784
+ qkT *= SM_SCALE
785
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
786
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
787
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
788
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
789
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
790
+
791
+ pre_mod_scores = qkT
792
+ tmp22 = (qkT)
793
+ post_mod_scores = tmp22
794
+
795
+
796
+
797
+ if not IS_DIVISIBLE:
798
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
799
+
800
+ if not IS_FULL_BLOCKS:
801
+ tmp23 = (m)
802
+ tmp24 = tl.full([1], 0, tl.int32)
803
+ tmp25 = tmp23 < tmp24
804
+ tmp26 = (n)
805
+ tmp27 = tmp26 <= tmp23
806
+ tmp28 = tmp25 & tmp27
807
+ tmp29 = tmp23 >= tmp24
808
+ tmp30 = tmp26 < tmp24
809
+ tmp31 = tmp29 & tmp30
810
+ tmp32 = tmp30 == 0
811
+ tmp33 = tmp29 & tmp32
812
+ tmp34 = tmp23 - tmp24
813
+ tmp35 = tl.full([1], 16, tl.int32)
814
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
815
+ tmp37 = tmp26 - tmp24
816
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
817
+ tmp39 = tmp36 == tmp38
818
+ tmp40 = tmp33 & tmp39
819
+ tmp41 = tmp31 | tmp40
820
+ tmp42 = tmp28 | tmp41
821
+ mask_mod_output = tmp42
822
+
823
+ # (grads) apply mask for fully masked block
824
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
825
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
826
+ if not PRESCALE_QK:
827
+ post_mod_scores *= RCP_LN2
828
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
829
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
830
+ # Compute dV.
831
+ ppT = pT
832
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
833
+ if IS_DIVISIBLE:
834
+ Di = tl.load(DELTA + offs_m1)
835
+ else:
836
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
837
+ # Compute dP and dS.
838
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
839
+ dsT = pT * (dpT - Di[None, :])
840
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
841
+ tmp43 = (dsT)
842
+ grad_scores = tmp43
843
+
844
+
845
+
846
+ if not IS_DIVISIBLE:
847
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
848
+
849
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
850
+ if not WRITE_DQ:
851
+ idx_b = off_z
852
+ idx_h = off_hq
853
+ idx_m = m
854
+ idx_n = n
855
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
856
+
857
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
858
+ dsT = grad_scores
859
+ if not IS_FULL_BLOCKS:
860
+ # (grads) apply mask for partially unmasked block
861
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
862
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
863
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
864
+
865
+ return dk, dv
866
+
867
+ # Utility triton funcs
868
+ @triton.jit
869
+ def get_offset_for_next_block(
870
+ loop_iter, col_indices, total_blocks,
871
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
872
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
873
+ ):
874
+ if BLOCKS_ARE_CONTIGUOUS:
875
+ return BLOCK
876
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
877
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
878
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
879
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
880
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
881
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
882
+ return offset
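+ # Editor's note (worked example, not generated code): with SPARSE_BLOCK = 128 and BLOCK = 64
+ # (so SPARSE_BLOCK_MULTIPLE = 2) and col_indices = [0, 3, ...], iterations inside a sparse
+ # block simply advance by BLOCK = 64; on the iteration that finishes sparse block 0 the
+ # offset becomes (3 - 0) * 128 - (2 - 1) * 64 = 320, which jumps the pointers from the second
+ # half of sparse block 0 to the start of sparse block 3.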
883
+
884
+ @triton.jit
885
+ def get_bounded_indices(indices, max_len=None):
886
+ return indices % max_len if max_len is not None else indices
887
+
888
+ @triton.jit
889
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
890
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
891
+ return tl.load(block_ptr)
892
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
893
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
894
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
895
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
896
+ else:
897
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
898
+
899
+ @triton.jit
900
+ def load_checked_2d(
901
+ ptr,
902
+ offs_m,
903
+ offs_n,
904
+ stride_m,
905
+ stride_n,
906
+ IS_DIVISIBLE_M: tl.constexpr,
907
+ IS_DIVISIBLE_N: tl.constexpr,
908
+ M_LEN: tl.constexpr,
909
+ N_LEN: tl.constexpr,
910
+ ):
911
+ # Calculate final pointer if strides are provided
912
+ if stride_m is not None and stride_n is not None:
913
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
914
+
915
+ # Handle all masking cases
916
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
917
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
918
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
919
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
920
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
921
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
922
+ else: # Both divisible
923
+ return tl.load(ptr)
924
+ ''', device_str='cuda')
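+
+ # Editor's note (a minimal sketch, not generated by Inductor; the helper name is hypothetical):
+ # the fixed grid passed to triton_tem_fused_mul_1.run(...) in Runner.call below follows the
+ # layout used inside the kernel: program id 0 covers ceil(KV_LEN / BLOCK_N1) dK/dV blocks
+ # followed by GQA_SHARED_HEADS * ceil(Q_LEN / BLOCK_M2) dQ blocks, program id 1 is the batch
+ # (ZQ = 1) and program id 2 is the kv-head count (HKV = 8).
+ def _flex_bwd_grid(q_len, kv_len, block_m2=128, block_n1=128, gqa_shared_heads=4, zq=1, hkv=8):
+     num_kv_blocks = -(-kv_len // block_n1)   # ceil division: dK/dV programs
+     num_q_blocks = -(-q_len // block_m2)     # ceil division: dQ programs per shared query head
+     return (num_kv_blocks + gqa_shared_heads * num_q_blocks, zq, hkv)
+
+ # For the benchmark shapes below (q_len = kv_len = 1568) this gives (65, 1, 8), matching
+ # 4*((127 + s37) // 128) + ((127 + s0) // 128) with s37 = s0 = 1568.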
925
+
926
+
927
+ async_compile.wait(globals())
928
+ del async_compile
929
+
930
+ class Runner:
931
+ def __init__(self, partitions):
932
+ self.partitions = partitions
933
+
934
+ def recursively_apply_fns(self, fns):
935
+ new_callables = []
936
+ for fn, c in zip(fns, self.partitions):
937
+ new_callables.append(fn(c))
938
+ self.partitions = new_callables
939
+
940
+ def call(self, args):
941
+ primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args
942
+ args.clear()
943
+ s37 = primals_10
944
+ s0 = primals_11
945
+ s22 = primals_7
946
+ s72 = primals_8
947
+ s99 = primals_12
948
+ assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
949
+ assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
950
+ assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
951
+ assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
952
+ assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1))
953
+ assert_size_stride(primals_14, (1, 1, 13), (13, 13, 1))
954
+ assert_size_stride(primals_15, (1, 1, 13, 13), (169, 169, 13, 1))
955
+ assert_size_stride(primals_16, (1, 1, 13), (13, 13, 1))
956
+ assert_size_stride(primals_17, (1, 1, 13, 13), (169, 169, 13, 1))
957
+ assert_size_stride(primals_18, (1, 1, 13), (13, 13, 1))
958
+ assert_size_stride(primals_19, (1, 1, 13, 13), (169, 169, 13, 1))
959
+ assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
960
+ assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
961
+ assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
962
+ assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
963
+ with torch.cuda._DeviceGuard(6):
964
+ torch.cuda.set_device(6)
965
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
966
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
967
+ triton_red_fused_mul_0_xnumel = 32*s37
968
+ stream6 = get_raw_stream(6)
969
+ triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream6)
970
+ del getitem
971
+ del tangents_2
972
+ buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
973
+ buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
974
+ buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
975
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
976
+ stream6 = get_raw_stream(6)
977
+ triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream6)
978
+ del buf1
979
+ del getitem_1
980
+ del primals_13
981
+ del primals_14
982
+ del primals_15
983
+ del primals_16
984
+ del primals_17
985
+ del primals_18
986
+ del primals_19
987
+ del primals_2
988
+ del primals_4
989
+ del primals_6
990
+ del primals_9
991
+ del tangents_1
992
+ return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, )
993
+
994
+ runner = Runner(partitions=[])
995
+ call = runner.call
996
+ recursively_apply_fns = runner.recursively_apply_fns
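+
+ # Editor's note (illustrative): in Runner.call above, buf3 is allocated with the query's
+ # shape/strides and passed as arg_DQ, buf4 with the value's shape and passed as arg_DV, and
+ # buf5 is the template's out_ptr0 (the dK written via store_output), so the three buffers
+ # returned are the query, key and value gradients. benchmark_compiled_module below shows a
+ # complete invocation with representative dynamic shapes (s37 = s0 = 1568).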
997
+
998
+
999
+ def benchmark_compiled_module(times=10, repeat=10):
1000
+ from torch._dynamo.testing import rand_strided
1001
+ from torch._inductor.utils import print_performance
1002
+ primals_10 = 1568
1003
+ primals_11 = 1568
1004
+ primals_7 = 13
1005
+ primals_8 = 13
1006
+ primals_12 = 13
1007
+ primals_2 = rand_strided((1, 32, 1568, 128), (6422528, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
1008
+ primals_4 = rand_strided((1, 8, 1568, 128), (1605632, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16)
1009
+ primals_6 = rand_strided((1, 8, 1568, 128), (1605632, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16)
1010
+ primals_9 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
1011
+ primals_13 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
1012
+ primals_14 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
1013
+ primals_15 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
1014
+ primals_16 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
1015
+ primals_17 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
1016
+ primals_18 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
1017
+ primals_19 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
1018
+ getitem = rand_strided((1, 32, 1568, 128), (6422528, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
1019
+ getitem_1 = rand_strided((1, 32, 1568), (50176, 1568, 1), device='cuda:6', dtype=torch.float32)
1020
+ tangents_1 = rand_strided((1, 32, 1568, 128), (6422528, 200704, 128, 1), device='cuda:6', dtype=torch.bfloat16)
1021
+ tangents_2 = rand_strided((1, 32, 1568), (50176, 1568, 1), device='cuda:6', dtype=torch.float32)
1022
+ fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2])
1023
+ return print_performance(fn, times=times, repeat=repeat)
1024
+
1025
+
1026
+ if __name__ == "__main__":
1027
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1028
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2m/50bb8a1cf8ca03a72155f9f84fae7c10cfe881f9f7d97e1e3f94ec85776c3639.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"}
progress/SpecForge/cache/compiled_kernels/2m/c2mzrxons6jvrj3mv77db5xyv5m4z73mego5v77sl66o6wuh5dbk.py ADDED
@@ -0,0 +1,28 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 65536},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x2 = xindex
23
+ x0 = (xindex % ks0)
24
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
25
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
26
+ tmp1 = 0.6931471805599453
27
+ tmp2 = tmp0 * tmp1
28
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
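+
+ # Editor's note (a minimal sketch, not generated code; the helper name is hypothetical): the
+ # repeated expression (1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)) is Inductor's branch-free
+ # spelling of max(1, ks0), used when a symbolic stride such as Max(1, s37) has to be emitted
+ # as plain arithmetic.
+ def _max1(ks0: int) -> int:
+     return 1 * (1 >= ks0) + ks0 * (ks0 > 1)
+
+ assert _max1(0) == 1 and _max1(1) == 1 and _max1(1712) == 1712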
progress/SpecForge/cache/compiled_kernels/2p/c2pjjgigeh2ro4r74dzlvlf7os5rhnmyche5rzwawor6zxb6rvk2.py ADDED
@@ -0,0 +1,1018 @@
1
+ # AOT ID: ['1_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/t3/ct3idxec76qtl3vshvvwnyu6rebkbok6z5sogtvlz6a62sxeduqp.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 219136, 128, 1]cuda:2" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[1, 32, 1712][55296, 1728, 1]cuda:2" = PlaceHolder[target=buf0]
44
+ # %tangents_2 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=tangents_2]
45
+ # %mul_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
46
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1712, 1712, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
47
+ # return %buf0,%buf1
48
+ triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', '''
49
+ import triton
50
+ import triton.language as tl
51
+
52
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
53
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
54
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
55
+ triton_helpers.set_driver_to_gpu()
56
+
57
+ @triton_heuristics.reduction(
58
+ size_hints={'x': 65536, 'r0_': 128},
59
+ reduction_hint=ReductionHint.DEFAULT,
60
+ filename=__file__,
61
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
62
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 657408, 'r0_': 28049408}}
63
+ )
64
+ @triton.jit
65
+ def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
66
+ xnumel = 54784
67
+ r0_numel = 128
68
+ rnumel = r0_numel
69
+ RBLOCK: tl.constexpr = R0_BLOCK
70
+ xoffset = tl.program_id(0) * XBLOCK
71
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
72
+ xmask = xindex < xnumel
73
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
74
+ rbase = r0_base
75
+ x0 = (xindex % 1712)
76
+ x1 = xindex // 1712
77
+ x3 = xindex
78
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_2 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last')
93
+ tmp6 = tmp4.to(tl.float32)
94
+ tmp8 = 0.6931471805599453
95
+ tmp9 = tmp7 * tmp8
96
+ tmp10 = 1.4426950408889634
97
+ tmp11 = tmp9 * tmp10
98
+ tmp12 = tmp6 - tmp11
99
+ tl.store(out_ptr1 + (x3), tmp12, xmask)
100
+ ''', device_str='cuda')
101
+
102
+
103
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xs/cxsqqwycgcgtkcz7s77tzmxqhofti6tljwwkkkb7agpib7z2otzr.py
104
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
105
+ # Source node to ATen node mapping:
106
+ # Graph fragment:
107
+ # %primals_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1]
108
+ # %primals_2 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_2]
109
+ # %primals_3 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_3]
110
+ # %getitem_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=getitem_1]
111
+ # %buf1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=buf1]
112
+ # %tangents_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 219136, 128, 1]cuda:2" = PlaceHolder[target=tangents_1]
113
+ # %getitem_3 : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3]
114
+ # %getitem_5 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=getitem_5]
115
+ # %primals_5 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_5]
116
+ # %primals_4 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_4]
117
+ # %primals_8 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_8]
118
+ # %primals_9 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_9]
119
+ # %primals_6 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_6]
120
+ # %primals_7 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_7]
121
+ # %primals_10 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_10]
122
+ # %primals_11 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_11]
123
+ # %mul_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
124
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1712, 1712, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
125
+ # return %getitem_4
126
+ triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
127
+ import triton
128
+ import triton.language as tl
129
+
130
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
131
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
132
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
133
+
134
+ @triton_heuristics.template(
135
+
136
+ num_stages=3,
137
+ num_warps=8,
138
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
139
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
140
+
141
+ )
142
+ @triton.jit
143
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
144
+ PRESCALE_QK : tl.constexpr = False
145
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
146
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
147
+ WRITE_DQ : tl.constexpr = True
148
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
149
+ OUTPUT_MAX : tl.constexpr = False
150
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
151
+ IS_DIVISIBLE : tl.constexpr = False
152
+ SM_SCALE : tl.constexpr = 0.08838834764831845
153
+ GQA_SHARED_HEADS : tl.constexpr = 4
154
+ HAS_FULL_BLOCKS : tl.constexpr = True
155
+ QK_HEAD_DIM : tl.constexpr = 128
156
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
157
+ V_HEAD_DIM : tl.constexpr = 128
158
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
159
+ SAFE_HEAD_DIM : tl.constexpr = True
160
+ BLOCK_M1 : tl.constexpr = 64
161
+ BLOCK_N1 : tl.constexpr = 128
162
+ BLOCK_M2 : tl.constexpr = 128
163
+ BLOCK_N2 : tl.constexpr = 64
164
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
165
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
166
+ INDEX_DTYPE : tl.constexpr = tl.int32
167
+ Q = arg_Q
168
+ K = arg_K
169
+ V = arg_V
170
+ LSE = arg_LSE
171
+ DELTA = arg_DELTA
172
+ DO = arg_DO
173
+ DQ = arg_DQ
174
+ DV = arg_DV
175
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
176
+ KV_IDX = arg_KV_IDX
177
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
178
+ Q_IDX = arg_Q_IDX
179
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
180
+ FULL_KV_IDX = arg_FULL_KV_IDX
181
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
182
+ FULL_Q_IDX = arg_FULL_Q_IDX
183
+
184
+ # Sub notation for this kernel:
185
+ #
186
+ # Q: Query, K: Key, V: Value
187
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
188
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
189
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
190
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
191
+ # inductor codegen
192
+ # M: Number of queries, N: Number of keys/values
193
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
194
+ # V_HEAD_DIM: The dimension of the value embeddings
195
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
196
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
197
+ # (Modifiable) Performance tuning options
198
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
199
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
200
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
201
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
202
+ #
203
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
204
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
205
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
206
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
207
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
208
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
209
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
210
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
211
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
212
+
213
+ # The below are kernel options that can be applied for certain score_mods,
214
+ # or involve a numerics vs. perf tradeoff
215
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
216
+ # about 20% more numerical error, but slightly faster.
217
+
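+ # Launch-grid sketch for this instantiation (an assumption read off the pid branch below):
+ # program_id(1) walks the query batch and program_id(2) the KV heads; program_id(0) first
+ # covers cdiv(KV_LEN, BLOCK_N1) programs that each accumulate dK/dV for one KV tile, then
+ # GQA_SHARED_HEADS * cdiv(Q_LEN, BLOCK_M2) programs that each accumulate dQ for one Q tile
+ # of one query head in the group.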
218
+ # Define strides of inputs
219
+ stride_qz, stride_qh, stride_qm, stride_qd = 7012352, 128, 4096, 1
220
+ stride_kz, stride_kh, stride_kn, stride_kd = 1753088, 128, 1024, 1
221
+ stride_vz, stride_vh, stride_vn, stride_vd = 1753088, 128, 1024, 1
222
+ stride_doz, stride_doh, stride_dom, stride_dod = 7012352, 219136, 128, 1
223
+
224
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 7012352, 128, 4096, 1
225
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1753088, 128, 1024, 1
226
+
227
+ ZQ = 1
228
+ HQ = 32
229
+ HKV = 8
230
+ Q_LEN = 1712
231
+ ZKV = 1
232
+ KV_LEN = 1712
233
+
234
+ MATMUL_PRECISION = Q.dtype.element_ty
235
+
236
+ pid = tl.program_id(0).to(INDEX_DTYPE)
237
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
238
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
239
+
240
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
241
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
242
+ off_zkv = off_zq % ZKV # kv batch idx
243
+
244
+ SPARSE_Z = 1
245
+ SPARSE_HQ = 1
246
+
247
+ sparse_idx_z = off_zq % SPARSE_Z
248
+
249
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
250
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
251
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
252
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
253
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
254
+
255
+ # offset K, V, DV pointers for batch/kv-head
256
+ K += k_adj
257
+ V += v_adj
258
+ DV += dv_adj
259
+
260
+ RCP_LN2 = 1.44269504
261
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
262
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
263
+
264
+ if pid >= NUM_KV_BLOCKS:
265
+ off_pid = pid - NUM_KV_BLOCKS
266
+ # THIS BLOCK DOES DQ
267
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
268
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
269
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
270
+ start_m2_block = off_pid % NUM_Q_BLOCKS
271
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
272
+ stride_kv_num_blks_h = 14
273
+ stride_kv_idx_h = 196
274
+ stride_kv_idx_m = 14
275
+
276
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
277
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
278
+
279
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
280
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
281
+
282
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
283
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
284
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
285
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
286
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
287
+
288
+ Q2 = Q + q_adj2
289
+ DO2 = DO + do_adj2
290
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
291
+ # if Q is broadcasted)
292
+ DQ2 = DQ + dq_adj2
293
+ LSE2 = LSE + off_chz2
294
+ DELTA2 = DELTA + off_chz2
295
+
296
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
297
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
298
+
299
+ start_m2 = start_m2_block * BLOCK_M2
300
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
301
+
302
+ # load Q and do: they stay in SRAM throughout the inner loop.
303
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
304
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
305
+
306
+ if PRESCALE_QK:
307
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
308
+
309
+ if IS_DIVISIBLE:
310
+ Di = tl.load(DELTA2 + offs_m2)
311
+ lse = tl.load(LSE2 + offs_m2)
312
+ else:
313
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
314
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
315
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
316
+ lse = lse[:, None]
317
+
318
+ # ~~~~~~~~~~~ partially masked blocks (mask_mod applied) ~~~~~~~~~~~~~~~~~~
319
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
320
+ kv_indices = KV_IDX + sparse_kv_idx_offset
321
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
322
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
323
+
324
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
325
+ dq = bwd_dq_inner(
326
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
327
+ K, V,
328
+ dq, q, do, Di, lse,
329
+ off_zq, off_hq2, offs_m2, offs_n2,
330
+ stride_kn, stride_kd, stride_vn, stride_vd,
331
+ kv_indices, sparse_kv_num_blocks,
332
+ MATMUL_PRECISION,
333
+ IS_FULL_BLOCKS=False,
334
+ )
335
+
336
+ if HAS_FULL_BLOCKS:
337
+ # ~~~~~~~~~~~ fully unmasked blocks (no mask_mod needed) ~~~~~~~~~~~~~~~~~~
338
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
339
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
340
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
341
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
342
+
343
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
344
+ dq = bwd_dq_inner(
345
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
346
+ K, V,
347
+ dq, q, do, Di, lse,
348
+ off_zq, off_hq2, offs_m2, offs_n2,
349
+ stride_kn, stride_kd, stride_vn, stride_vd,
350
+ kv_indices, sparse_kv_num_blocks,
351
+ MATMUL_PRECISION,
352
+ IS_FULL_BLOCKS=True,
353
+ )
354
+
355
+ # Write back dQ.
356
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
357
+ dq *= SM_SCALE
358
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
359
+ tl.store(dq_ptrs, dq)
360
+ else:
361
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
362
+ else:
363
+ # THIS BLOCK DOES DK & DV
364
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
365
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
366
+
367
+ pid_mask = pid // SPARSE_KV_MULTIPLE
368
+
369
+ stride_q_num_blks_h = 14
370
+ stride_q_idx_h = 196
371
+ stride_q_idx_n = 14
372
+
373
+
374
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
375
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
376
+
377
+ start_n1 = pid * BLOCK_N1
378
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
379
+
380
+ # load K and V: they stay in SRAM throughout the inner loop.
381
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
382
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
383
+
384
+ if PRESCALE_QK:
385
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
386
+
387
+ for off_g in range(0, GQA_SHARED_HEADS):
388
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
389
+
390
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
391
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
392
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
393
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
394
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
395
+
396
+ Q1 = Q + q_adj1
397
+ DO1 = DO + do_adj1
398
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
399
+ # if Q is broadcasted)
400
+ LSE1 = LSE + off_chz1
401
+ DELTA1 = DELTA + off_chz1
402
+
403
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
404
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
405
+
406
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
407
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
408
+
409
+ # ~~~~~~~~~~~~~~~ partially masked blocks (mask_mod applied) ~~~~~~~~~~~~~~~~~~~~
410
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
411
+ q_indices = Q_IDX + sparse_q_idx_offset
412
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
413
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
414
+
415
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
416
+ dk, dv = bwd_dkdv_inner(
417
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
418
+ Q1, DO1, DELTA1, LSE1,
419
+ dk, dv, k, v,
420
+ off_zq, off_hq1, offs_n1, offs_m1,
421
+ stride_qm, stride_qd, stride_dom, stride_dod,
422
+ q_indices, sparse_q_num_blocks,
423
+ MATMUL_PRECISION,
424
+ IS_FULL_BLOCKS=False,
425
+ )
426
+
427
+
428
+ if HAS_FULL_BLOCKS:
429
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
430
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
431
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
432
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
433
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
434
+
435
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
436
+ dk, dv = bwd_dkdv_inner(
437
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
438
+ Q1, DO1, DELTA1, LSE1,
439
+ dk, dv, k, v,
440
+ off_zq, off_hq1, offs_n1, offs_m1,
441
+ stride_qm, stride_qd, stride_dom, stride_dod,
442
+ q_indices, sparse_q_num_blocks,
443
+ MATMUL_PRECISION,
444
+ IS_FULL_BLOCKS=True,
445
+ )
446
+
447
+ # Write back dV and dK.
448
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
449
+
450
+ index_n = offs_n1[:, None]
451
+ index_k = offs_k[None, :]
452
+ index_v = offs_v[None, :]
453
+
454
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
455
+ tl.store(dv_ptrs, dv)
456
+ else:
457
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
458
+
459
+ dk *= SM_SCALE
460
+
461
+ if SAFE_HEAD_DIM:
462
+ mask = index_n < KV_LEN
463
+ else:
464
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
465
+
466
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
467
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
468
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
469
+ xindex = index_k + 128*index_n + 219136*off_hkv + 1753088*off_zq
470
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
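+ # dK leaves the kernel through the template's store_output epilogue (out_ptr0); in the caller
+ # further down, out_ptr0 appears to be buf5, and the index expression above matches its
+ # (1753088, 128, 1024, 1) layout, with the batch term dropped since the batch size is 1.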
471
+
472
+ @triton.jit
473
+ def bwd_dq_inner(
474
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
475
+ K, V, # pointers
476
+ dq, q, do, Di, lse,
477
+ off_z, off_hq, offs_m2, offs_n2,
478
+ stride_kn, stride_kd, stride_vn, stride_vd,
479
+ kv_indices, sparse_kv_num_blocks,
480
+ MATMUL_PRECISION,
481
+ IS_FULL_BLOCKS,
482
+ ):
483
+ PRESCALE_QK : tl.constexpr = False
484
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
485
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
486
+ WRITE_DQ : tl.constexpr = True
487
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
488
+ OUTPUT_MAX : tl.constexpr = False
489
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
490
+ IS_DIVISIBLE : tl.constexpr = False
491
+ SM_SCALE : tl.constexpr = 0.08838834764831845
492
+ GQA_SHARED_HEADS : tl.constexpr = 4
493
+ HAS_FULL_BLOCKS : tl.constexpr = True
494
+ QK_HEAD_DIM : tl.constexpr = 128
495
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
496
+ V_HEAD_DIM : tl.constexpr = 128
497
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
498
+ SAFE_HEAD_DIM : tl.constexpr = True
499
+ BLOCK_M1 : tl.constexpr = 64
500
+ BLOCK_N1 : tl.constexpr = 128
501
+ BLOCK_M2 : tl.constexpr = 128
502
+ BLOCK_N2 : tl.constexpr = 64
503
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
504
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
505
+ INDEX_DTYPE : tl.constexpr = tl.int32
506
+
507
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
508
+ RCP_LN2: tl.constexpr = 1.44269504
509
+ Q_LEN = 1712
510
+ KV_LEN = 1712
511
+
512
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
513
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
514
+
515
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
516
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
517
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
518
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
519
+
520
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
521
+
522
+ for start_n in range(0, hi):
523
+ dq = bwd_dq_block_mn(
524
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
525
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
526
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
527
+ stride_kn, stride_kd, stride_vn, stride_vd,
528
+ kv_indices, sparse_kv_num_blocks,
529
+ MATMUL_PRECISION, RCP_LN2,
530
+ IS_FULL_BLOCKS,
531
+ )
532
+
533
+ # Increment pointers.
534
+ offset = get_offset_for_next_block(
535
+ start_n, kv_indices, sparse_kv_num_blocks,
536
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
537
+ )
538
+
539
+ kT_ptrs += offset * stride_kn
540
+ vT_ptrs += offset * stride_vn
541
+
542
+ offs_n2 += offset
543
+
544
+ return dq
545
+
546
+
547
+ @triton.jit
548
+ def bwd_dq_block_mn(
549
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
550
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
551
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
552
+ stride_kn, stride_kd, stride_vn, stride_vd,
553
+ kv_indices, sparse_kv_num_blocks,
554
+ MATMUL_PRECISION, RCP_LN2,
555
+ IS_FULL_BLOCKS,
556
+ ):
557
+ PRESCALE_QK : tl.constexpr = False
558
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
559
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
560
+ WRITE_DQ : tl.constexpr = True
561
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
562
+ OUTPUT_MAX : tl.constexpr = False
563
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
564
+ IS_DIVISIBLE : tl.constexpr = False
565
+ SM_SCALE : tl.constexpr = 0.08838834764831845
566
+ GQA_SHARED_HEADS : tl.constexpr = 4
567
+ HAS_FULL_BLOCKS : tl.constexpr = True
568
+ QK_HEAD_DIM : tl.constexpr = 128
569
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
570
+ V_HEAD_DIM : tl.constexpr = 128
571
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
572
+ SAFE_HEAD_DIM : tl.constexpr = True
573
+ BLOCK_M1 : tl.constexpr = 64
574
+ BLOCK_N1 : tl.constexpr = 128
575
+ BLOCK_M2 : tl.constexpr = 128
576
+ BLOCK_N2 : tl.constexpr = 64
577
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
578
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
579
+ INDEX_DTYPE : tl.constexpr = tl.int32
580
+
581
+
582
+ # NB reversed order since K is transposed
583
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
584
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
585
+ if not PRESCALE_QK:
586
+ qk *= SM_SCALE
587
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
588
+ pre_mod_scores = qk
589
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
590
+ # The boundary check is done for the outer loop, but here it's possible, since we're iterating across the N dim,
591
+ # that the M reads go out of bounds for the PIDs spanning the Q_LEN boundary.
592
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
593
+
594
+ tmp0 = (qk)
595
+ post_mod_scores = tmp0
596
+
597
+
598
+
599
+
600
+ if not IS_DIVISIBLE:
601
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
602
+
603
+ if not IS_FULL_BLOCKS:
604
+ tmp1 = (m)
605
+ tmp2 = tl.full([1], 0, tl.int32)
606
+ tmp3 = tmp1 < tmp2
607
+ tmp4 = (n)
608
+ tmp5 = tmp4 <= tmp1
609
+ tmp6 = tmp3 & tmp5
610
+ tmp7 = tmp1 >= tmp2
611
+ tmp8 = tmp4 < tmp2
612
+ tmp9 = tmp7 & tmp8
613
+ tmp10 = tmp8 == 0
614
+ tmp11 = tmp7 & tmp10
615
+ tmp12 = tmp1 - tmp2
616
+ tmp13 = tl.full([1], 16, tl.int32)
617
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
618
+ tmp15 = tmp4 - tmp2
619
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
620
+ tmp17 = tmp14 == tmp16
621
+ tmp18 = tmp11 & tmp17
622
+ tmp19 = tmp9 | tmp18
623
+ tmp20 = tmp6 | tmp19
624
+ mask_mod_output = tmp20
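+ # The tmp1..tmp20 block is the graph's mask_mod inlined by inductor. Since the comparison
+ # constant here is 0 and the indices from get_bounded_indices are non-negative, the expression
+ # effectively reduces to (q_idx // 16) == (kv_idx // 16), i.e. a block-diagonal pattern of
+ # width 16 for this trace.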
625
+
626
+
627
+ # apply mask for partial masked block
628
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
629
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
630
+ if not PRESCALE_QK:
631
+ post_mod_scores *= RCP_LN2
632
+ p = tl.math.exp2(post_mod_scores - lse)
633
+ # Compute dP and dS.
634
+ # NB reversed order since V is transposed
635
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
636
+
637
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
638
+ ds = p * (dp - Di[:, None])
639
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
640
+ tmp21 = (ds)
641
+ grad_scores = tmp21
642
+
643
+
644
+ if not IS_DIVISIBLE:
645
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
646
+
647
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
648
+ if WRITE_DQ:
649
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
650
+
651
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
652
+ ds = grad_scores
653
+
654
+ if not IS_FULL_BLOCKS:
655
+ # (grads) apply mask for partially unmasked block
656
+ ds = tl.where(mask_mod_output, ds, 0.0)
657
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
658
+ ds = ds.to(MATMUL_PRECISION)
659
+ # Compute dQ.
660
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
661
+
662
+ return dq
663
+
664
+
665
+ @triton.jit
666
+ def bwd_dkdv_inner(
667
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
668
+ Q, DO, DELTA, LSE, # pointers
669
+ dk, dv, k, v,
670
+ off_z, off_hq, offs_n1, offs_m1,
671
+ stride_qm, stride_qd, stride_dom, stride_dod,
672
+ q_indices, sparse_q_num_blocks,
673
+ MATMUL_PRECISION,
674
+ IS_FULL_BLOCKS,
675
+ ):
676
+ PRESCALE_QK : tl.constexpr = False
677
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
678
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
679
+ WRITE_DQ : tl.constexpr = True
680
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
681
+ OUTPUT_MAX : tl.constexpr = False
682
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
683
+ IS_DIVISIBLE : tl.constexpr = False
684
+ SM_SCALE : tl.constexpr = 0.08838834764831845
685
+ GQA_SHARED_HEADS : tl.constexpr = 4
686
+ HAS_FULL_BLOCKS : tl.constexpr = True
687
+ QK_HEAD_DIM : tl.constexpr = 128
688
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
689
+ V_HEAD_DIM : tl.constexpr = 128
690
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
691
+ SAFE_HEAD_DIM : tl.constexpr = True
692
+ BLOCK_M1 : tl.constexpr = 64
693
+ BLOCK_N1 : tl.constexpr = 128
694
+ BLOCK_M2 : tl.constexpr = 128
695
+ BLOCK_N2 : tl.constexpr = 64
696
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
697
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
698
+ INDEX_DTYPE : tl.constexpr = tl.int32
699
+
700
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
701
+ RCP_LN2: tl.constexpr = 1.44269504
702
+ Q_LEN = 1712
703
+ KV_LEN = 1712
704
+
705
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
706
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
707
+
708
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
709
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
710
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
711
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
712
+
713
+ # The minimum is needed to handle the case where we run with a super large
714
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
715
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
716
+
717
+ for start_m in range(0, hi):
718
+ dk, dv = bwd_dkdv_block_mn(
719
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
720
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
721
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
722
+ stride_qm, stride_qd, stride_dom, stride_dod,
723
+ q_indices, sparse_q_num_blocks,
724
+ MATMUL_PRECISION, RCP_LN2,
725
+ IS_FULL_BLOCKS,
726
+ )
727
+ # Increment pointers.
728
+ offset = get_offset_for_next_block(
729
+ start_m, q_indices, sparse_q_num_blocks,
730
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
731
+ )
732
+
733
+ qT_ptrs += offset * stride_qm
734
+ do_ptrs += offset * stride_dom
735
+ offs_m1 += offset
736
+
737
+ return dk, dv
738
+
739
+
740
+ @triton.jit
741
+ def bwd_dkdv_block_mn(
742
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
743
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
744
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
745
+ stride_qm, stride_qd, stride_dom, stride_dod,
746
+ q_indices, sparse_q_num_blocks,
747
+ MATMUL_PRECISION, RCP_LN2,
748
+ IS_FULL_BLOCKS,
749
+ ):
750
+ PRESCALE_QK : tl.constexpr = False
751
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
752
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
753
+ WRITE_DQ : tl.constexpr = True
754
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
755
+ OUTPUT_MAX : tl.constexpr = False
756
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
757
+ IS_DIVISIBLE : tl.constexpr = False
758
+ SM_SCALE : tl.constexpr = 0.08838834764831845
759
+ GQA_SHARED_HEADS : tl.constexpr = 4
760
+ HAS_FULL_BLOCKS : tl.constexpr = True
761
+ QK_HEAD_DIM : tl.constexpr = 128
762
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
763
+ V_HEAD_DIM : tl.constexpr = 128
764
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
765
+ SAFE_HEAD_DIM : tl.constexpr = True
766
+ BLOCK_M1 : tl.constexpr = 64
767
+ BLOCK_N1 : tl.constexpr = 128
768
+ BLOCK_M2 : tl.constexpr = 128
769
+ BLOCK_N2 : tl.constexpr = 64
770
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
771
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
772
+ INDEX_DTYPE : tl.constexpr = tl.int32
773
+
774
+
775
+ # NB reversed order since Q is transposed
776
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
777
+ # Load LSE before computing qk to reduce pipeline stall.
778
+ if IS_DIVISIBLE:
779
+ lse = tl.load(LSE + offs_m1)
780
+ else:
781
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
782
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
783
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
784
+ if not PRESCALE_QK:
785
+ qkT *= SM_SCALE
786
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
787
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
788
+ # The boundary check is done for the outer loop, but here it's possible, since we're iterating across the M dim,
789
+ # that the n reads go out of bounds for the PIDs spanning the KV_LEN boundary.
790
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
791
+
792
+ pre_mod_scores = qkT
793
+ tmp22 = (qkT)
794
+ post_mod_scores = tmp22
795
+
796
+
797
+
798
+ if not IS_DIVISIBLE:
799
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
800
+
801
+ if not IS_FULL_BLOCKS:
802
+ tmp23 = (m)
803
+ tmp24 = tl.full([1], 0, tl.int32)
804
+ tmp25 = tmp23 < tmp24
805
+ tmp26 = (n)
806
+ tmp27 = tmp26 <= tmp23
807
+ tmp28 = tmp25 & tmp27
808
+ tmp29 = tmp23 >= tmp24
809
+ tmp30 = tmp26 < tmp24
810
+ tmp31 = tmp29 & tmp30
811
+ tmp32 = tmp30 == 0
812
+ tmp33 = tmp29 & tmp32
813
+ tmp34 = tmp23 - tmp24
814
+ tmp35 = tl.full([1], 16, tl.int32)
815
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
816
+ tmp37 = tmp26 - tmp24
817
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
818
+ tmp39 = tmp36 == tmp38
819
+ tmp40 = tmp33 & tmp39
820
+ tmp41 = tmp31 | tmp40
821
+ tmp42 = tmp28 | tmp41
822
+ mask_mod_output = tmp42
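+ # Same inlined mask_mod as in bwd_dq_block_mn, evaluated in the transposed orientation
+ # (rows index KV positions, columns index query positions) used by the dK/dV path.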
823
+
824
+ # (grads) apply mask for fully masked block
825
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
826
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
827
+ if not PRESCALE_QK:
828
+ post_mod_scores *= RCP_LN2
829
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
830
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
831
+ # Compute dV.
832
+ ppT = pT
833
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
834
+ if IS_DIVISIBLE:
835
+ Di = tl.load(DELTA + offs_m1)
836
+ else:
837
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
838
+ # Compute dP and dS.
839
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
840
+ dsT = pT * (dpT - Di[None, :])
841
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
842
+ tmp43 = (dsT)
843
+ grad_scores = tmp43
844
+
845
+
846
+
847
+ if not IS_DIVISIBLE:
848
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
849
+
850
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
851
+ if not WRITE_DQ:
852
+ idx_b = off_z
853
+ idx_h = off_hq
854
+ idx_m = m
855
+ idx_n = n
856
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
857
+
858
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
859
+ dsT = grad_scores
860
+ if not IS_FULL_BLOCKS:
861
+ # (grads) apply mask for partially unmasked block
862
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
863
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
864
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
865
+
866
+ return dk, dv
867
+
868
+ # Utility triton funcs
869
+ @triton.jit
870
+ def get_offset_for_next_block(
871
+ loop_iter, col_indices, total_blocks,
872
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
873
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
874
+ ):
875
+ if BLOCKS_ARE_CONTIGUOUS:
876
+ return BLOCK
877
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
878
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
879
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
880
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
881
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
882
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
883
+ return offset
884
+
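+ # Worked example for the configuration in this file (SPARSE_*_BLOCK_SIZE = 128, inner
+ # BLOCK = 64, so SPARSE_BLOCK_MULTIPLE = 2): even loop_iter values simply advance the
+ # pointers by 64 within the current sparse block; odd values return
+ # (next_block - cur_block) * 128 - 64, landing the pointers at the start of the next
+ # block listed in col_indices.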
885
+ @triton.jit
886
+ def get_bounded_indices(indices, max_len=None):
887
+ return indices % max_len if max_len is not None else indices
888
+
889
+ @triton.jit
890
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
891
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
892
+ return tl.load(block_ptr)
893
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
894
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
895
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
896
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
897
+ else:
898
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
899
+
900
+ @triton.jit
901
+ def load_checked_2d(
902
+ ptr,
903
+ offs_m,
904
+ offs_n,
905
+ stride_m,
906
+ stride_n,
907
+ IS_DIVISIBLE_M: tl.constexpr,
908
+ IS_DIVISIBLE_N: tl.constexpr,
909
+ M_LEN: tl.constexpr,
910
+ N_LEN: tl.constexpr,
911
+ ):
912
+ # Calculate final pointer if strides are provided
913
+ if stride_m is not None and stride_n is not None:
914
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
915
+
916
+ # Handle all masking cases
917
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
918
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
919
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
920
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
921
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
922
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
923
+ else: # Both divisible
924
+ return tl.load(ptr)
925
+ ''', device_str='cuda')
926
+
927
+
928
+ async_compile.wait(globals())
929
+ del async_compile
930
+
931
+ class Runner:
932
+ def __init__(self, partitions):
933
+ self.partitions = partitions
934
+
935
+ def recursively_apply_fns(self, fns):
936
+ new_callables = []
937
+ for fn, c in zip(fns, self.partitions):
938
+ new_callables.append(fn(c))
939
+ self.partitions = new_callables
940
+
941
+ def call(self, args):
942
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args
943
+ args.clear()
944
+ assert_size_stride(primals_1, (1, 32, 1712, 128), (7012352, 128, 4096, 1))
945
+ assert_size_stride(primals_2, (1, 8, 1712, 128), (1753088, 128, 1024, 1))
946
+ assert_size_stride(primals_3, (1, 8, 1712, 128), (1753088, 128, 1024, 1))
947
+ assert_size_stride(primals_4, (1, 1, 14, 14), (196, 196, 14, 1))
948
+ assert_size_stride(primals_5, (1, 1, 14), (14, 14, 1))
949
+ assert_size_stride(primals_6, (1, 1, 14), (14, 14, 1))
950
+ assert_size_stride(primals_7, (1, 1, 14, 14), (196, 196, 14, 1))
951
+ assert_size_stride(primals_8, (1, 1, 14), (14, 14, 1))
952
+ assert_size_stride(primals_9, (1, 1, 14, 14), (196, 196, 14, 1))
953
+ assert_size_stride(primals_10, (1, 1, 14), (14, 14, 1))
954
+ assert_size_stride(primals_11, (1, 1, 14, 14), (196, 196, 14, 1))
955
+ assert_size_stride(getitem, (1, 32, 1712, 128), (7012352, 128, 4096, 1))
956
+ assert_size_stride(getitem_1, (1, 32, 1712), (54784, 1712, 1))
957
+ assert_size_stride(tangents_1, (1, 32, 1712, 128), (7012352, 219136, 128, 1))
958
+ assert_size_stride(tangents_2, (1, 32, 1712), (54784, 1712, 1))
959
+ with torch.cuda._DeviceGuard(2):
960
+ torch.cuda.set_device(2)
961
+ buf1 = empty_strided_cuda((1, 32, 1712), (54784, 1712, 1), torch.float32)
962
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
963
+ stream2 = get_raw_stream(2)
964
+ triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 54784, 128, stream=stream2)
965
+ del getitem
966
+ del tangents_2
967
+ buf3 = empty_strided_cuda((1, 32, 1712, 128), (7012352, 128, 4096, 1), torch.bfloat16)
968
+ buf4 = empty_strided_cuda((1, 8, 1712, 128), (1753088, 128, 1024, 1), torch.bfloat16)
969
+ buf5 = empty_strided_cuda((1, 8, 1712, 128), (1753088, 128, 1024, 1), torch.bfloat16)
970
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
971
+ stream2 = get_raw_stream(2)
972
+ triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 70, 1, 8, stream=stream2)
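+ # The fixed grid (70, 1, 8) appears to decompose as: 70 = cdiv(1712, 128) dK/dV programs
+ # + 4 * cdiv(1712, 128) dQ programs (GQA_SHARED_HEADS = 4), over 1 batch and 8 KV heads.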
973
+ del buf1
974
+ del getitem_1
975
+ del primals_1
976
+ del primals_10
977
+ del primals_11
978
+ del primals_2
979
+ del primals_3
980
+ del primals_4
981
+ del primals_5
982
+ del primals_6
983
+ del primals_7
984
+ del primals_8
985
+ del primals_9
986
+ del tangents_1
987
+ return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, )
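+ # The three dense outputs appear to map to (dQ, dK, dV) for primals_1..primals_3:
+ # buf3 = dQ, buf5 = dK (written via the template epilogue), buf4 = dV; the trailing
+ # Nones correspond to the eight integer block-mask inputs, which receive no gradient.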
988
+
989
+ runner = Runner(partitions=[])
990
+ call = runner.call
991
+ recursively_apply_fns = runner.recursively_apply_fns
992
+
993
+
994
+ def benchmark_compiled_module(times=10, repeat=10):
995
+ from torch._dynamo.testing import rand_strided
996
+ from torch._inductor.utils import print_performance
997
+ primals_1 = rand_strided((1, 32, 1712, 128), (7012352, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
998
+ primals_2 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
999
+ primals_3 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
1000
+ primals_4 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
1001
+ primals_5 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
1002
+ primals_6 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
1003
+ primals_7 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
1004
+ primals_8 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
1005
+ primals_9 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
1006
+ primals_10 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
1007
+ primals_11 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
1008
+ getitem = rand_strided((1, 32, 1712, 128), (7012352, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
1009
+ getitem_1 = rand_strided((1, 32, 1712), (54784, 1712, 1), device='cuda:2', dtype=torch.float32)
1010
+ tangents_1 = rand_strided((1, 32, 1712, 128), (7012352, 219136, 128, 1), device='cuda:2', dtype=torch.bfloat16)
1011
+ tangents_2 = rand_strided((1, 32, 1712), (54784, 1712, 1), device='cuda:2', dtype=torch.float32)
1012
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2])
1013
+ return print_performance(fn, times=times, repeat=repeat)
1014
+
1015
+
1016
+ if __name__ == "__main__":
1017
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1018
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2x/c2xecscuz5jhvznv7jn4k545b7kcexuko5lz3em6woeo7u2ftonz.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 2048, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ x2 = (xindex % ks0)
34
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
35
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
36
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
39
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
40
+ tmp6 = float("-inf")
41
+ tmp7 = tmp4 == tmp6
42
+ tmp8 = tmp0 - tmp4
43
+ tmp9 = 0.0
44
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
45
+ tmp11 = libdevice.exp2(tmp10)
46
+ tmp12 = tmp5 * tmp11
47
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
48
+ tmp15 = tl.where(xmask, tmp13, 0)
49
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
50
+ tmp17 = 1.0
51
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
52
+ tmp19 = libdevice.log2(tmp18)
53
+ tmp20 = tmp19 + tmp4
54
+ tmp21 = 0.6931471805599453
55
+ tmp22 = tmp20 * tmp21
56
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
57
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
58
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
progress/SpecForge/cache/compiled_kernels/2x/c2ximikyisa7xxnki36flzcsdr4ziwruq7ujf3zymsuxon5pqv57.py ADDED
@@ -0,0 +1,707 @@
1
+ # AOT ID: ['4_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ue/cueevzv7vmofb7xgliazywafcctru3ytzoxylqvshvpe6nweecz6.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2]
43
+ # %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4]
44
+ # %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6]
45
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=getitem_1]
46
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1]
47
+ # %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_10]
48
+ # %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_7]
49
+ # %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_11]
50
+ # %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_12]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %getitem
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=8,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ SM_SCALE : tl.constexpr = 0.08838834764831845
80
+ GQA_SHARED_HEADS : tl.constexpr = 4
81
+ HAS_FULL_BLOCKS : tl.constexpr = True
82
+ QK_HEAD_DIM : tl.constexpr = 128
83
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
84
+ V_HEAD_DIM : tl.constexpr = 128
85
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ SAFE_HEAD_DIM : tl.constexpr = True
87
+ USE_TMA : tl.constexpr = False
88
+ BLOCK_M : tl.constexpr = 128
89
+ BLOCK_N : tl.constexpr = 64
90
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
91
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
92
+ INDEX_DTYPE : tl.constexpr = tl.int32
93
+ Q = arg_Q
94
+ K = arg_K
95
+ V = arg_V
96
+ LSE = arg_LSE
97
+ MAX = arg_MAX
98
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
99
+ KV_IDX = arg_KV_IDX
100
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
101
+ FULL_KV_IDX = arg_FULL_KV_IDX
102
+
103
+ # Sub notation for this kernel:
104
+ #
105
+ # Q: Query, K: Key, V: Value
106
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
107
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
108
+ # V_HEAD_DIM: The dimension of the value embeddings
109
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
110
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
111
+ #
112
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
113
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
114
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
115
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
116
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
117
+ #
118
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
119
+ #
120
+ # (Modifiable) Performance tuning options
121
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
122
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
123
+
124
+ # The below are kernel options that can be applied for certain score_mods,
125
+ # or involve a numerics vs. perf tradeoff
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
127
+ # about 20% more numerical error, but slightly faster.
128
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
129
+ # is not masked out? If so, we can skip an extra safety check
130
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
131
+ # contiguous? If so, we don't need to do an indirect jump for every block
132
+
133
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
134
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
135
+
136
+ # Define strides of inputs
137
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
138
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
139
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
140
+
141
+ ZQ = 1
142
+ HQ = 32
143
+ Q_LEN = ks0
144
+ ZKV = 1
145
+ KV_LEN = ks1
146
+
147
+ MATMUL_PRECISION = Q.dtype.element_ty
148
+
149
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
150
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
151
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
152
+
153
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
154
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
155
+ off_zkv = off_zq % ZKV
156
+ off_hkv = off_hq // GQA_SHARED_HEADS
157
+ off_g = off_hq % GQA_SHARED_HEADS
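+ # e.g. with GQA_SHARED_HEADS = 4, query heads 0-3 all map to kv head 0: off_hkv = off_hq // 4
+ # picks the shared kv head and off_g = off_hq % 4 is the query head's index within that group.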
158
+
159
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
160
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
161
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
162
+
163
+ Q = Q + q_offset
164
+ K = K + k_offset
165
+ V = V + v_offset
166
+
167
+ # Setting up the TMA descriptors for Q, K, V
168
+ desc_q = None
169
+ desc_k = None
170
+ desc_v = None
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_zq % SPARSE_Z
176
+ sparse_idx_hq = off_hq % SPARSE_HQ
177
+
178
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
179
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
180
+
181
+ stride_kv_num_blks_h = 1
182
+ stride_kv_idx_h = 1
183
+ stride_kv_idx_m = 1
184
+
185
+ # initialize pointer to m and l
186
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
187
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
188
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
189
+
190
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
191
+
192
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
193
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
194
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
195
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
196
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
197
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
198
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
199
+
200
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201
+ # We don't know anything "special" about these blocks, so we need to apply
202
+ # both score_mod and mask_mod to it
203
+ kv_indices = KV_IDX + sparse_kv_idx_offset
204
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
205
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
206
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
207
+
208
+
209
+ # K and V pointers will be passed directly to forward_inner
210
+
211
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
212
+
213
+
214
+ acc, l_i, m_i = forward_inner(
215
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
216
+ q, K, V,
217
+ desc_k, desc_v, Q_LEN, KV_LEN,
218
+ acc, l_i, m_i,
219
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
220
+ kv_start,
221
+ kv_indices, kv_num_blocks,
222
+ 0, block_n_end,
223
+ MATMUL_PRECISION,
224
+ stride_kk, stride_kn, stride_vn, stride_vk,
225
+ IS_FULL_BLOCKS=False,
226
+ )
227
+
228
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229
+ # We know these blocks are guaranteed to be "full", so we don't need to
230
+ # apply mask_mod to them - only score_mod
231
+ if HAS_FULL_BLOCKS:
232
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
233
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
234
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
235
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
236
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
237
+ # K and V pointers will be passed directly to forward_inner
238
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
242
+ q, K, V,
243
+ desc_k, desc_v, Q_LEN, KV_LEN,
244
+ acc, l_i, m_i,
245
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
246
+ kv_start,
247
+ kv_indices, kv_num_blocks,
248
+ 0, block_n_end,
249
+ MATMUL_PRECISION,
250
+ stride_kk, stride_kn, stride_vn, stride_vk,
251
+ IS_FULL_BLOCKS=True,
252
+ )
253
+
254
+
255
+ # [Note] Handle fully masked out rows:
256
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
257
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
258
+ l_i = tl.where(l_i == 0.0, 1, l_i)
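+ # (With l_i forced to 1, the division below leaves the all-zero accumulator at 0 for fully masked
+ # rows, and log2(l_i) contributes nothing to the logsumexp stored further down.)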
259
+
260
+ acc = acc / l_i[:, None]
261
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
262
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
263
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
264
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
265
+
266
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
267
+
268
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
269
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
270
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
271
+
272
+ if OUTPUT_LOGSUMEXP:
273
+ off_hz = off_zq * HQ + off_hq
274
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
275
+ lse = m_i + tl.math.log2(l_i)
276
+ if IS_DIVISIBLE:
277
+ tl.store(l_ptrs, lse)
278
+ else:
279
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
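+ # Note: m_i and l_i are accumulated in base-2 units (scores are multiplied by RCP_LN2 before exp2),
+ # so this lse is a log2-sum-exp; the wrapper's triton_poi_fused_mul_1 kernel later multiplies it by
+ # ln(2) = 0.6931471805599453 to produce the natural-log "lse_scaled" buffer.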
280
+
281
+ if OUTPUT_MAX:
282
+ off_hz = off_zq * HQ + off_hq
283
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
284
+ if IS_DIVISIBLE:
285
+ tl.store(max_ptrs, m_i)
286
+ else:
287
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
288
+
289
+
290
+ # Utility triton funcs
291
+ @triton.jit
292
+ def get_offset_for_next_block(
293
+ loop_iter, col_indices, total_blocks,
294
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
295
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
296
+ ):
297
+ if BLOCKS_ARE_CONTIGUOUS:
298
+ return BLOCK
299
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
300
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
301
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
302
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
303
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
304
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
305
+ return offset
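+ # e.g. with SPARSE_KV_BLOCK_SIZE = 128 and BLOCK_N = 64 (SPARSE_BLOCK_MULTIPLE = 2): inside a sparse
+ # block the returned offset is just BLOCK (64); on the iteration that finishes a sparse block it is
+ # (next_block - cur_block) * 128 - 64, landing the pointers at the start of the next block in col_indices.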
306
+
307
+ @triton.jit
308
+ def get_bounded_indices(indices, max_len=None):
309
+ return indices % max_len if max_len is not None else indices
310
+
311
+ @triton.jit
312
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
313
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
314
+ return tl.load(block_ptr)
315
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
317
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
319
+ else:
320
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
321
+
322
+ @triton.jit
323
+ def load_checked_2d(
324
+ ptr,
325
+ offs_m,
326
+ offs_n,
327
+ stride_m,
328
+ stride_n,
329
+ IS_DIVISIBLE_M: tl.constexpr,
330
+ IS_DIVISIBLE_N: tl.constexpr,
331
+ M_LEN: tl.constexpr,
332
+ N_LEN: tl.constexpr,
333
+ ):
334
+ # Calculate final pointer if strides are provided
335
+ if stride_m is not None and stride_n is not None:
336
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
337
+
338
+ # Handle all masking cases
339
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
340
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
341
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
343
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
345
+ else: # Both divisible
346
+ return tl.load(ptr)
347
+
348
+
349
+ # Common Imports
350
+ @triton.jit
351
+ def forward_block_mn(
352
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
353
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
354
+ # accumulated values
355
+ acc, l_i, m_i,
356
+ # Offsets
357
+ off_z, off_h, offs_m, offs_n,
358
+ # Offsets needed for TMA loads
359
+ kv_start,
360
+ kv_offset,
361
+ MATMUL_PRECISION, RCP_LN2,
362
+ # Strides for K and V
363
+ stride_kk, stride_kn, stride_vn, stride_vk,
364
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
365
+
366
+ ):
367
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
368
+ PRESCALE_QK : tl.constexpr = False
369
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
370
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
371
+ WRITE_DQ : tl.constexpr = True
372
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
373
+ OUTPUT_MAX : tl.constexpr = False
374
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
375
+ IS_DIVISIBLE : tl.constexpr = False
376
+ SM_SCALE : tl.constexpr = 0.08838834764831845
377
+ GQA_SHARED_HEADS : tl.constexpr = 4
378
+ HAS_FULL_BLOCKS : tl.constexpr = True
379
+ QK_HEAD_DIM : tl.constexpr = 128
380
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
381
+ V_HEAD_DIM : tl.constexpr = 128
382
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ SAFE_HEAD_DIM : tl.constexpr = True
384
+ USE_TMA : tl.constexpr = False
385
+ BLOCK_M : tl.constexpr = 128
386
+ BLOCK_N : tl.constexpr = 64
387
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
388
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
389
+ INDEX_DTYPE : tl.constexpr = tl.int32
390
+
391
+
392
+ # -- load k --
393
+ # NB: reversed order since K is transposed
394
+ kv_base_offset = kv_start + kv_offset
395
+
396
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
397
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
398
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
399
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
400
+
401
+ k = tl.trans(k)
402
+ # -- compute qk ---
403
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
404
+ if not PRESCALE_QK:
405
+ qk *= SM_SCALE
406
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
407
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
408
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
409
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
410
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
411
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
412
+
413
+ tmp0 = (qk)
414
+ post_mod_scores = tmp0
415
+
416
+
417
+ if CHECK_BLOCK_BOUNDARY:
418
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
419
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
420
+
421
+ if not IS_FULL_BLOCKS:
422
+ tmp1 = (m)
423
+ tmp2 = tl.full([1], 0, tl.int32)
424
+ tmp3 = tmp1 < tmp2
425
+ tmp4 = (n)
426
+ tmp5 = tmp4 <= tmp1
427
+ tmp6 = tmp3 & tmp5
428
+ tmp7 = tmp1 >= tmp2
429
+ tmp8 = tmp4 < tmp2
430
+ tmp9 = tmp7 & tmp8
431
+ tmp10 = tmp8 == 0
432
+ tmp11 = tmp7 & tmp10
433
+ tmp12 = tmp1 - tmp2
434
+ tmp13 = tl.full([1], 16, tl.int32)
435
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
436
+ tmp15 = tmp4 - tmp2
437
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
438
+ tmp17 = tmp14 == tmp16
439
+ tmp18 = tmp11 & tmp17
440
+ tmp19 = tmp9 | tmp18
441
+ tmp20 = tmp6 | tmp19
442
+ mask_mod_output = tmp20
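+ # The inlined mask_mod above appears to evaluate to: (m < 0 and n <= m) or (m >= 0 and n < 0) or
+ # (m >= 0 and n >= 0 and m // 16 == n // 16); since m and n are non-negative indices here, this
+ # effectively keeps only positions within the same 16-wide block along the diagonal.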
443
+
444
+
445
+ if CHECK_BLOCK_BOUNDARY:
446
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
447
+ # apply mask for partially unmasked blocks
448
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
449
+
450
+ if not PRESCALE_QK:
451
+ post_mod_scores *= RCP_LN2
452
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
453
+
454
+ # -- compute scaling constant ---
455
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
456
+ if not ROWS_GUARANTEED_SAFE:
457
+ masked_out_rows = (m_ij == float("-inf"))
458
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
459
+ else:
460
+ m_ij_masked = m_ij
461
+
462
+ alpha = tl.math.exp2(m_i - m_ij_masked)
463
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
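+ # Online-softmax update: m_ij is the new running row max (in base-2 units), alpha = 2^(m_i - m_ij)
+ # rescales the previously accumulated l_i and acc, and p holds this block's unnormalized
+ # probabilities relative to the new max.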
464
+
465
+ # NB: l_i update is pulled up here since it's a bit faster
466
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
467
+ # m_ij
468
+ l_i = l_i * alpha + tl.sum(p, 1)
469
+ # # -- scale and update acc --
470
+ acc = acc * alpha[:, None]
471
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
472
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
473
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
474
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
475
+
476
+ # -- update m_i
477
+ m_i = m_ij
478
+
479
+ return acc, l_i, m_i
480
+
481
+ @triton.jit
482
+ def forward_inner(
483
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
484
+ q, K, V,
485
+ desc_k, desc_v, Q_LEN, KV_LEN,
486
+ # accumulated values
487
+ acc, l_i, m_i,
488
+ # Offsets used as inputs to score_mod & mask_mod
489
+ # of size [BLOCK_M, BLOCK_N] or scalar.
490
+ off_z, off_h, offs_m, offs_n,
491
+ # Offsets needed for TMA loads
492
+ kv_start,
493
+ # blocksparse data
494
+ kv_indices, kv_num_blocks,
495
+ # start kv and end kv block
496
+ block_n_start, block_n_end,
497
+ MATMUL_PRECISION,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ ):
502
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
503
+ PRESCALE_QK : tl.constexpr = False
504
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
505
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
506
+ WRITE_DQ : tl.constexpr = True
507
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
508
+ OUTPUT_MAX : tl.constexpr = False
509
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
510
+ IS_DIVISIBLE : tl.constexpr = False
511
+ SM_SCALE : tl.constexpr = 0.08838834764831845
512
+ GQA_SHARED_HEADS : tl.constexpr = 4
513
+ HAS_FULL_BLOCKS : tl.constexpr = True
514
+ QK_HEAD_DIM : tl.constexpr = 128
515
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
516
+ V_HEAD_DIM : tl.constexpr = 128
517
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
518
+ SAFE_HEAD_DIM : tl.constexpr = True
519
+ USE_TMA : tl.constexpr = False
520
+ BLOCK_M : tl.constexpr = 128
521
+ BLOCK_N : tl.constexpr = 64
522
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
523
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
524
+ INDEX_DTYPE : tl.constexpr = tl.int32
525
+
526
+
527
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
528
+ RCP_LN2: tl.constexpr = 1.44269504
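+ # RCP_LN2 = 1/ln(2) = log2(e), so exp2 can stand in for exp: exp(x) == 2^(x * RCP_LN2).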
529
+
530
+ if PRESCALE_QK:
531
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
532
+
533
+ kv_offset = 0
534
+
535
+ # loop over k, v and update accumulator until block_n_end
536
+ for start_n in range(block_n_start, block_n_end):
537
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
538
+ if IS_DIVISIBLE:
539
+ acc, l_i, m_i = forward_block_mn(
540
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
541
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
542
+ # accumulated values
543
+ acc, l_i, m_i,
544
+ # Offsets
545
+ off_z, off_h, offs_m, offs_n,
546
+ # Offsets needed for TMA loads
547
+ kv_start,
548
+ kv_offset,
549
+ MATMUL_PRECISION, RCP_LN2,
550
+ # Strides for K and V
551
+ stride_kk, stride_kn, stride_vn, stride_vk,
552
+ IS_FULL_BLOCKS,
553
+ )
554
+ else:
555
+ # Benchmark shows that even when we apply mod & mask to each block for a non-divisible seqlen,
556
+ # it's on par or slightly faster than only applying to the last block in fwd.
557
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
558
+ # to the last block because it is significantly faster.
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
573
+ )
574
+
575
+
576
+
577
+ offset = get_offset_for_next_block(
578
+ start_n, kv_indices, kv_num_blocks,
579
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
580
+ )
581
+
582
+ offs_n = offs_n + offset
583
+ kv_offset += offset
584
+
585
+
586
+ return acc, l_i, m_i
587
+ ''', device_str='cuda')
588
+
589
+
590
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/iv/civfnuj45ifcjf545fyz4ryg2q422dh2larygr2jr5yyli7gz6ae.py
591
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
592
+ # Source node to ATen node mapping:
593
+ # lse_scaled => mul_15
594
+ # Graph fragment:
595
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
596
+ # %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
597
+ # return %mul_15
598
+ triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
599
+ import triton
600
+ import triton.language as tl
601
+
602
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
603
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
604
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
605
+ triton_helpers.set_driver_to_gpu()
606
+
607
+ @triton_heuristics.pointwise(
608
+ size_hints={'x': 4096},
609
+ filename=__file__,
610
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
611
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
612
+ min_elem_per_thread=0
613
+ )
614
+ @triton.jit
615
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
616
+ xoffset = tl.program_id(0) * XBLOCK
617
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
618
+ xmask = xindex < xnumel
619
+ x2 = xindex
620
+ x0 = (xindex % ks0)
621
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
622
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
623
+ tmp1 = 0.6931471805599453
624
+ tmp2 = tmp0 * tmp1
625
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
626
+ ''', device_str='cuda')
627
+
628
+
629
+ async_compile.wait(globals())
630
+ del async_compile
631
+
632
+ class Runner:
633
+ def __init__(self, partitions):
634
+ self.partitions = partitions
635
+
636
+ def recursively_apply_fns(self, fns):
637
+ new_callables = []
638
+ for fn, c in zip(fns, self.partitions):
639
+ new_callables.append(fn(c))
640
+ self.partitions = new_callables
641
+
642
+ def call(self, args):
643
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args
644
+ args.clear()
645
+ s50 = primals_1
646
+ s0 = primals_3
647
+ s43 = primals_5
648
+ s37 = primals_8
649
+ s71 = primals_9
650
+ assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
651
+ assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
652
+ assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
653
+ assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
654
+ assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
655
+ assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
656
+ assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
657
+ assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
658
+ assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
659
+ assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
660
+ assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
661
+ with torch.cuda._DeviceGuard(3):
662
+ torch.cuda.set_device(3)
663
+ buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
664
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
665
+ buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
666
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
667
+ stream3 = get_raw_stream(3)
668
+ triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream3)
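+ # Launch grid is ((s37 + 127) // 128, 1, 32): one program per BLOCK_M = 128 query rows, a single
+ # batch, and 32 query heads, matching program_id(0/1/2) inside triton_tem_fused_0.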
669
+ del buf1
670
+ buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
671
+ # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
672
+ triton_poi_fused_mul_1_xnumel = 32*s37
673
+ stream3 = get_raw_stream(3)
674
+ triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream3)
675
+ return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, )
676
+
677
+ runner = Runner(partitions=[])
678
+ call = runner.call
679
+ recursively_apply_fns = runner.recursively_apply_fns
680
+
681
+
682
+ def benchmark_compiled_module(times=10, repeat=10):
683
+ from torch._dynamo.testing import rand_strided
684
+ from torch._inductor.utils import print_performance
685
+ primals_1 = 128
686
+ primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
687
+ primals_3 = 128
688
+ primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
689
+ primals_5 = 128
690
+ primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
691
+ primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
692
+ primals_8 = 128
693
+ primals_9 = 128
694
+ primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
695
+ primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
696
+ primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
697
+ primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
698
+ primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
699
+ primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
700
+ primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
701
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16])
702
+ return print_performance(fn, times=times, repeat=repeat)
703
+
704
+
705
+ if __name__ == "__main__":
706
+ from torch._inductor.wrapper_benchmark import compiled_module_main
707
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/2x/dde0479cca0d878e6e0800ec13f7c80962354e837542bfc5f11f7b49306d323e.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"}
progress/SpecForge/cache/compiled_kernels/3h/6ee97c795357f97e7127237e15db9bd5fb14510b837eeb5094115cfaa1802d32.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/3h/c3h3fb5vqykgr7s3powfrnsc5alooplbijdgjizqo3xq5psavrvz.py ADDED
@@ -0,0 +1,28 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4096},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x2 = xindex
23
+ x0 = (xindex % ks0)
24
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
25
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
26
+ tmp1 = 0.6931471805599453
27
+ tmp2 = tmp0 * tmp1
28
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
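+ # (0.6931471805599453 is ln 2; the multiply evidently converts a base-2 logsumexp into natural log.)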
progress/SpecForge/cache/compiled_kernels/3h/c3hro2ygwh2ixqhmbrrdsjq6biaehv6lm5cbeo6yhlo6ssqkwpha.py ADDED
@@ -0,0 +1,799 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
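+ # ((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) is a branch-free spelling of max(1, ks0).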
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 1
148
+ stride_kv_idx_h = 1
149
+ stride_kv_idx_m = 1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 1
245
+ stride_q_idx_h = 1
246
+ stride_q_idx_n = 1
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB: reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB: reversed order since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
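+ # Softmax backward identity: ds = p * (dp - Di), where Di = sum(out * do, axis=-1) was precomputed
+ # into DELTA, i.e. dS = P * (dP - rowsum(P * dP)).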
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
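+ # Fully-masked rows carry lse == -inf from the forward pass; zeroing it keeps exp2(post_mod_scores - lse) finite (0 instead of NaN) for those rows.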
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here, since we're iterating across the M dim,
664
+ # it's possible that the n reads go out of bounds for the PIDs spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
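+ # Within a sparse block we simply advance by BLOCK; when we cross a sparse-block boundary we jump from the current indexed block to the next one.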
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/3p/c3pdrhexk4rwol7f5l5vh7n543dj6piq6gw5k66g2p4vlyhopnop.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 2048, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ x2 = (xindex % ks0)
34
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
35
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
36
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
39
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
40
+ tmp6 = float("-inf")
41
+ tmp7 = tmp4 == tmp6
42
+ tmp8 = tmp0 - tmp4
43
+ tmp9 = 0.0
44
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
45
+ tmp11 = libdevice.exp2(tmp10)
46
+ tmp12 = tmp5 * tmp11
47
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
48
+ tmp15 = tl.where(xmask, tmp13, 0)
49
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
50
+ tmp17 = 1.0
51
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
52
+ tmp19 = libdevice.log2(tmp18)
53
+ tmp20 = tmp19 + tmp4
54
+ tmp21 = 0.6931471805599453
55
+ tmp22 = tmp20 * tmp21
56
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
57
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
58
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
progress/SpecForge/cache/compiled_kernels/3p/d4e91f4bc49d9cfc59a03caa3a2e04988f99c358762e8d23eed306dbbe3eae25.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 69, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"}
progress/SpecForge/cache/compiled_kernels/3s/c3spq2k2yeawxvgwl4dczrad6qwkidiiyxz5xwsucqivwlx625g7.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
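+ # A row that is fully masked out has m_ij == -inf; substituting 0 keeps the exp2 rescaling below finite for that row.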
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmarks show that even if we apply mod & mask to each block for non-divisible seqlen,
504
+ # it's on par or slightly faster than only applying to the last block in fwd.
505
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's a lot faster.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
progress/SpecForge/cache/compiled_kernels/3u/c3umapah7vcozhvfk5uovlssor7v533y4crphqgd677nuoizbpvj.py ADDED
@@ -0,0 +1,799 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 1
148
+ stride_kv_idx_h = 1
149
+ stride_kv_idx_m = 1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 1
245
+ stride_q_idx_h = 1
246
+ stride_q_idx_n = 1
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
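# A minimal standalone sketch (plain Python, not part of the generated kernel file) of the
# GQA head mapping used in the dK/dV branch above: each KV head accumulates gradient
# contributions from the GQA_SHARED_HEADS query heads that share it.  The constants
# (HQ=32, HKV=8, GQA_SHARED_HEADS=4) are taken from these kernels' constexpr blocks.
HQ, HKV, GQA_SHARED_HEADS = 32, 8, 4
assert HQ == HKV * GQA_SHARED_HEADS

def query_heads_for_kv_head(off_hkv):
    # Mirrors the loop above: off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
    return [off_hkv * GQA_SHARED_HEADS + off_g for off_g in range(GQA_SHARED_HEADS)]

assert query_heads_for_kv_head(0) == [0, 1, 2, 3]
assert query_heads_for_kv_head(7) == [28, 29, 30, 31]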
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here, since we're iterating across the N dim,
466
+ # it's possible that the M reads go out of bounds for the PIDs spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
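# A minimal NumPy sketch (not part of the generated kernel file) checking the softmax-backward
# identity used in bwd_dq_block_mn above: for p = softmax(s) and upstream grad dp, the grad
# w.r.t. the scores is ds = p * (dp - sum(p * dp)).  In the kernel, that row-wise sum is Di,
# the precomputed sum(OUT * DO, axis=-1) described in the kernel's notation comment.
import numpy as np

rng = np.random.default_rng(0)
s = rng.normal(size=8)      # one row of attention scores
dp = rng.normal(size=8)     # upstream gradient w.r.t. the probabilities

p = np.exp(s - s.max())
p /= p.sum()
ds_kernel = p * (dp - np.dot(p, dp))        # form used in the kernel

# Reference: full softmax Jacobian (diag(p) - p p^T) applied to dp.
J = np.diag(p) - np.outer(p, p)
ds_ref = J @ dp
assert np.allclose(ds_kernel, ds_ref)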
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but here, since we're iterating across the M dim,
664
+ # it's possible that the n reads go out of bounds for the PIDs spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
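# A minimal pure-Python sketch (not part of the generated kernel file) of the pointer-advance
# logic in get_offset_for_next_block above: within a sparse block we step by BLOCK, and on the
# last sub-block we jump to the start of the next sparse block listed in col_indices.
# The example indices [0, 3, 7] are made up for illustration.
def offsets_visited(col_indices, SPARSE_BLOCK=128, BLOCK=64):
    multiple = SPARSE_BLOCK // BLOCK
    pos = col_indices[0] * SPARSE_BLOCK
    visited = []
    for loop_iter in range(len(col_indices) * multiple):
        visited.append(pos)
        cur_block_idx = loop_iter // multiple
        cur_block = col_indices[cur_block_idx]
        next_block = col_indices[min(cur_block_idx + 1, len(col_indices) - 1)]
        needs_jump = (loop_iter + 1) % multiple == 0
        jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (multiple - 1) * BLOCK
        pos += jump_to_block if needs_jump else BLOCK
    return visited

# Visits both 64-wide halves of sparse blocks 0, 3 and 7, in order.
assert offsets_visited([0, 3, 7]) == [0, 64, 384, 448, 896, 960]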
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
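# A minimal NumPy sketch (not part of the generated kernel file) of what load_checked_2d above
# does: a 2D tile load where rows/cols past M_LEN/N_LEN are replaced with 0.0, matching the
# mask / other=0.0 combinations selected by IS_DIVISIBLE_M and IS_DIVISIBLE_N.
import numpy as np

def load_checked_2d_ref(buf, offs_m, offs_n, M_LEN, N_LEN):
    offs_m = np.asarray(offs_m)[:, None]
    offs_n = np.asarray(offs_n)[None, :]
    in_bounds = (offs_m < M_LEN) & (offs_n < N_LEN)
    # Clamp indices only for the gather; out-of-bounds positions are zeroed by the mask.
    gathered = buf[np.minimum(offs_m, M_LEN - 1), np.minimum(offs_n, N_LEN - 1)]
    return np.where(in_bounds, gathered, 0.0)

buf = np.arange(12.0).reshape(3, 4)            # a 3x4 "tensor"
tile = load_checked_2d_ref(buf, [1, 2, 3], [2, 3, 4], M_LEN=3, N_LEN=4)
assert tile[-1].sum() == 0.0 and tile[:, -1].sum() == 0.0   # padded row and column are zero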
progress/SpecForge/cache/compiled_kernels/42/c424arzgjg22xrcyl4orsbfthh3vxddttchjdd7yswdd5pdxdhtv.py ADDED
@@ -0,0 +1,1019 @@
1
+ # AOT ID: ['3_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ns/cnsdsjpw2wa2anwrjd5vp4pkfq4mtbebhfmyzeol2hd4mzwu6qnk.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
44
+ # %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=tangents_2]
45
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
46
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
47
+ # return %buf0,%buf1
48
+ triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', '''
49
+ import triton
50
+ import triton.language as tl
51
+
52
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
53
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
54
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
55
+ triton_helpers.set_driver_to_gpu()
56
+
57
+ @triton_heuristics.persistent_reduction(
58
+ size_hints={'x': 4096, 'r0_': 128},
59
+ reduction_hint=ReductionHint.INNER,
60
+ filename=__file__,
61
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
62
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
63
+ )
64
+ @triton.jit
65
+ def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
66
+ r0_numel = 128
67
+ R0_BLOCK: tl.constexpr = 128
68
+ rnumel = r0_numel
69
+ RBLOCK: tl.constexpr = R0_BLOCK
70
+ xoffset = tl.program_id(0) * XBLOCK
71
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
72
+ xmask = xindex < xnumel
73
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
74
+ r0_offset = 0
75
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
76
+ roffset = r0_offset
77
+ rindex = r0_index
78
+ r0_2 = r0_index
79
+ x0 = (xindex % ks0)
80
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
81
+ x3 = xindex
82
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
83
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
84
+ tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
85
+ tmp2 = tmp0 * tmp1
86
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
87
+ tmp5 = tl.where(xmask, tmp3, 0)
88
+ tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
89
+ tmp7 = tmp6.to(tl.float32)
90
+ tmp9 = 0.6931471805599453
91
+ tmp10 = tmp8 * tmp9
92
+ tmp11 = 1.4426950408889634
93
+ tmp12 = tmp10 * tmp11
94
+ tmp13 = tmp7 - tmp12
95
+ tl.store(out_ptr1 + (x3), tmp13, xmask)
96
+ ''', device_str='cuda')
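# A minimal NumPy sketch (not part of the generated kernel file) of what triton_per_fused_mul_0
# above appears to produce per (query, head) row: the DELTA buffer described below as
# sum(OUT * DO, axis=-1), with what appears to be the incoming logsumexp tangent subtracted.
# The two literals in the kernel are ln(2) and log2(e); their product is 1, so that scaling
# cancels.  Shapes here are made up for illustration.
import numpy as np

assert abs(0.6931471805599453 * 1.4426950408889634 - 1.0) < 1e-15

out = np.random.default_rng(1).normal(size=(5, 128))    # forward output rows
dout = np.random.default_rng(2).normal(size=(5, 128))   # upstream grad of the output
dlse = np.random.default_rng(3).normal(size=5)          # upstream grad of logsumexp

delta = (out * dout).sum(axis=-1) - dlse
assert delta.shape == (5,)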
97
+
98
+
99
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2]
104
+ # %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4]
105
+ # %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6]
106
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5]
111
+ # %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_10]
112
+ # %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_7]
113
+ # %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_13]
114
+ # %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_14]
115
+ # %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_11]
116
+ # %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_12]
117
+ # %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_15]
118
+ # %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_16]
119
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
120
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
121
+ # return %getitem_4
122
+ triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
123
+ import triton
124
+ import triton.language as tl
125
+
126
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
127
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
128
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
129
+
130
+ @triton_heuristics.template(
131
+
132
+ num_stages=3,
133
+ num_warps=8,
134
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
135
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
136
+
137
+ )
138
+ @triton.jit
139
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
140
+ PRESCALE_QK : tl.constexpr = False
141
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
142
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
143
+ WRITE_DQ : tl.constexpr = True
144
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
145
+ OUTPUT_MAX : tl.constexpr = False
146
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
147
+ IS_DIVISIBLE : tl.constexpr = False
148
+ SM_SCALE : tl.constexpr = 0.08838834764831845
149
+ GQA_SHARED_HEADS : tl.constexpr = 4
150
+ HAS_FULL_BLOCKS : tl.constexpr = True
151
+ QK_HEAD_DIM : tl.constexpr = 128
152
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
153
+ V_HEAD_DIM : tl.constexpr = 128
154
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
155
+ SAFE_HEAD_DIM : tl.constexpr = True
156
+ BLOCK_M1 : tl.constexpr = 64
157
+ BLOCK_N1 : tl.constexpr = 128
158
+ BLOCK_M2 : tl.constexpr = 128
159
+ BLOCK_N2 : tl.constexpr = 64
160
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
161
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
162
+ INDEX_DTYPE : tl.constexpr = tl.int32
163
+ Q = arg_Q
164
+ K = arg_K
165
+ V = arg_V
166
+ LSE = arg_LSE
167
+ DELTA = arg_DELTA
168
+ DO = arg_DO
169
+ DQ = arg_DQ
170
+ DV = arg_DV
171
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
172
+ KV_IDX = arg_KV_IDX
173
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
174
+ Q_IDX = arg_Q_IDX
175
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
176
+ FULL_KV_IDX = arg_FULL_KV_IDX
177
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
178
+ FULL_Q_IDX = arg_FULL_Q_IDX
179
+
180
+ # Sub notation for this kernel:
181
+ #
182
+ # Q: Query, K: Key, V: Value
183
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
184
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
185
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
186
+ # DK: Derivative of Key, is written to via the store_output call due to some limitations with
187
+ # inductor codegen
188
+ # M: Number of queries, N: Number of keys/values
189
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
190
+ # V_HEAD_DIM: The dimension of the value embeddings
191
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
192
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
193
+ # (Modifiable) Performance tuning options
194
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
195
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
196
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
197
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
198
+ #
199
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
200
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
201
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
202
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
203
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
204
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
205
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
207
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
208
+
209
+ # The below are kernel options that can be applied for certain score_mods,
210
+ # or involve a numerics vs. perf tradeoff
211
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
212
+ # about 20% more numerical error, but slightly faster.
213
+
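# A small hand-built illustration (hypothetical, not data from this kernel) of the block-sparse
# metadata described in the notation comment above, for a causal mask over 4 q-blocks x 4
# kv-blocks: the diagonal block of each row still needs masking (KV_*), while the blocks
# strictly below the diagonal are fully unmasked (FULL_KV_*).
kv_num_blks      = [1, 1, 1, 1]                  # one partially-masked (diagonal) block per q-block row
kv_idx           = [[0], [1], [2], [3]]
full_kv_num_blks = [0, 1, 2, 3]                  # fully-unmasked blocks below the diagonal
full_kv_idx      = [[], [0], [0, 1], [0, 1, 2]]

for row in range(4):
    visible = sorted(kv_idx[row][:kv_num_blks[row]] + full_kv_idx[row][:full_kv_num_blks[row]])
    assert visible == list(range(row + 1))       # causal: kv blocks 0..row are visited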
214
+ # Define strides of inputs
215
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
216
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
217
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
218
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
219
+
220
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
221
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
222
+
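# A quick check (plain Python, not part of the generated kernel file) that the branchless
# expression used in the stride definitions above,
#     (1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)),
# is just max(1, ks0): exactly one of the two comparisons is true for any integer ks0.
for ks0 in range(-2, 5):
    assert (1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)) == max(1, ks0)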
223
+ ZQ = 1
224
+ HQ = 32
225
+ HKV = 8
226
+ Q_LEN = ks0
227
+ ZKV = 1
228
+ KV_LEN = ks1
229
+
230
+ MATMUL_PRECISION = Q.dtype.element_ty
231
+
232
+ pid = tl.program_id(0).to(INDEX_DTYPE)
233
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
234
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
235
+
236
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
237
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
238
+ off_zkv = off_zq % ZKV # kv batch idx
239
+
240
+ SPARSE_Z = 1
241
+ SPARSE_HQ = 1
242
+
243
+ sparse_idx_z = off_zq % SPARSE_Z
244
+
245
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
246
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
247
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
248
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
249
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
250
+
251
+ # offset K, V, DV pointers for batch/kv-head
252
+ K += k_adj
253
+ V += v_adj
254
+ DV += dv_adj
255
+
256
+ RCP_LN2 = 1.44269504
257
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
258
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
259
+
260
+ if pid >= NUM_KV_BLOCKS:
261
+ off_pid = pid - NUM_KV_BLOCKS
262
+ # THIS BLOCK DOES DQ
263
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
264
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
265
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
266
+ start_m2_block = off_pid % NUM_Q_BLOCKS
267
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
268
+ stride_kv_num_blks_h = 1
269
+ stride_kv_idx_h = 1
270
+ stride_kv_idx_m = 1
271
+
272
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
273
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
274
+
275
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
276
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
277
+
278
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
279
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
280
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
281
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
282
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
283
+
284
+ Q2 = Q + q_adj2
285
+ DO2 = DO + do_adj2
286
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
287
+ # if Q is broadcasted)
288
+ DQ2 = DQ + dq_adj2
289
+ LSE2 = LSE + off_chz2
290
+ DELTA2 = DELTA + off_chz2
291
+
292
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
293
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
294
+
295
+ start_m2 = start_m2_block * BLOCK_M2
296
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
297
+
298
+ # load Q and do: they stay in SRAM throughout the inner loop.
299
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
300
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
301
+
302
+ if PRESCALE_QK:
303
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
304
+
305
+ if IS_DIVISIBLE:
306
+ Di = tl.load(DELTA2 + offs_m2)
307
+ lse = tl.load(LSE2 + offs_m2)
308
+ else:
309
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
310
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
312
+ lse = lse[:, None]
313
+
314
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
315
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
316
+ kv_indices = KV_IDX + sparse_kv_idx_offset
317
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
318
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
319
+
320
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
321
+ dq = bwd_dq_inner(
322
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
323
+ K, V,
324
+ dq, q, do, Di, lse,
325
+ off_zq, off_hq2, offs_m2, offs_n2,
326
+ stride_kn, stride_kd, stride_vn, stride_vd,
327
+ kv_indices, sparse_kv_num_blocks,
328
+ MATMUL_PRECISION,
329
+ IS_FULL_BLOCKS=False,
330
+ )
331
+
332
+ if HAS_FULL_BLOCKS:
333
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
334
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
335
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
336
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
337
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
338
+
339
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
340
+ dq = bwd_dq_inner(
341
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
342
+ K, V,
343
+ dq, q, do, Di, lse,
344
+ off_zq, off_hq2, offs_m2, offs_n2,
345
+ stride_kn, stride_kd, stride_vn, stride_vd,
346
+ kv_indices, sparse_kv_num_blocks,
347
+ MATMUL_PRECISION,
348
+ IS_FULL_BLOCKS=True,
349
+ )
350
+
351
+ # Write back dQ.
352
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
353
+ dq *= SM_SCALE
354
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
355
+ tl.store(dq_ptrs, dq)
356
+ else:
357
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
358
+ else:
359
+ # THIS BLOCK DOES DK & DV
360
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
361
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
362
+
363
+ pid_mask = pid // SPARSE_KV_MULTIPLE
364
+
365
+ stride_q_num_blks_h = 1
366
+ stride_q_idx_h = 1
367
+ stride_q_idx_n = 1
368
+
369
+
370
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
371
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+
373
+ start_n1 = pid * BLOCK_N1
374
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
375
+
376
+ # load K and V: they stay in SRAM throughout the inner loop.
377
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
378
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
379
+
380
+ if PRESCALE_QK:
381
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
382
+
383
+ for off_g in range(0, GQA_SHARED_HEADS):
384
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
385
+
386
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
387
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
388
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
389
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
390
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
391
+
392
+ Q1 = Q + q_adj1
393
+ DO1 = DO + do_adj1
394
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
395
+ # if Q is broadcasted)
396
+ LSE1 = LSE + off_chz1
397
+ DELTA1 = DELTA + off_chz1
398
+
399
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
400
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
401
+
402
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
403
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
404
+
405
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
406
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
407
+ q_indices = Q_IDX + sparse_q_idx_offset
408
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
409
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
410
+
411
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
412
+ dk, dv = bwd_dkdv_inner(
413
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
414
+ Q1, DO1, DELTA1, LSE1,
415
+ dk, dv, k, v,
416
+ off_zq, off_hq1, offs_n1, offs_m1,
417
+ stride_qm, stride_qd, stride_dom, stride_dod,
418
+ q_indices, sparse_q_num_blocks,
419
+ MATMUL_PRECISION,
420
+ IS_FULL_BLOCKS=False,
421
+ )
422
+
423
+
424
+ if HAS_FULL_BLOCKS:
425
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
426
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
427
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
428
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
429
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
430
+
431
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
432
+ dk, dv = bwd_dkdv_inner(
433
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
434
+ Q1, DO1, DELTA1, LSE1,
435
+ dk, dv, k, v,
436
+ off_zq, off_hq1, offs_n1, offs_m1,
437
+ stride_qm, stride_qd, stride_dom, stride_dod,
438
+ q_indices, sparse_q_num_blocks,
439
+ MATMUL_PRECISION,
440
+ IS_FULL_BLOCKS=True,
441
+ )
442
+
443
+ # Write back dV and dK.
444
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
445
+
446
+ index_n = offs_n1[:, None]
447
+ index_k = offs_k[None, :]
448
+ index_v = offs_v[None, :]
449
+
450
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
451
+ tl.store(dv_ptrs, dv)
452
+ else:
453
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
454
+
455
+ dk *= SM_SCALE
456
+
457
+ if SAFE_HEAD_DIM:
458
+ mask = index_n < KV_LEN
459
+ else:
460
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
461
+
462
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
463
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
464
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
465
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
466
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
467
+
468
+ @triton.jit
469
+ def bwd_dq_inner(
470
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
471
+ K, V, # pointers
472
+ dq, q, do, Di, lse,
473
+ off_z, off_hq, offs_m2, offs_n2,
474
+ stride_kn, stride_kd, stride_vn, stride_vd,
475
+ kv_indices, sparse_kv_num_blocks,
476
+ MATMUL_PRECISION,
477
+ IS_FULL_BLOCKS,
478
+ ):
479
+ PRESCALE_QK : tl.constexpr = False
480
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
481
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
482
+ WRITE_DQ : tl.constexpr = True
483
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
484
+ OUTPUT_MAX : tl.constexpr = False
485
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
486
+ IS_DIVISIBLE : tl.constexpr = False
487
+ SM_SCALE : tl.constexpr = 0.08838834764831845
488
+ GQA_SHARED_HEADS : tl.constexpr = 4
489
+ HAS_FULL_BLOCKS : tl.constexpr = True
490
+ QK_HEAD_DIM : tl.constexpr = 128
491
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
492
+ V_HEAD_DIM : tl.constexpr = 128
493
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
494
+ SAFE_HEAD_DIM : tl.constexpr = True
495
+ BLOCK_M1 : tl.constexpr = 64
496
+ BLOCK_N1 : tl.constexpr = 128
497
+ BLOCK_M2 : tl.constexpr = 128
498
+ BLOCK_N2 : tl.constexpr = 64
499
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
500
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
501
+ INDEX_DTYPE : tl.constexpr = tl.int32
502
+
503
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
504
+ RCP_LN2: tl.constexpr = 1.44269504
505
+ Q_LEN = ks0
506
+ KV_LEN = ks1
507
+
508
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
509
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
510
+
511
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
512
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
513
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
514
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
515
+
516
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
517
+
518
+ for start_n in range(0, hi):
519
+ dq = bwd_dq_block_mn(
520
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
521
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
522
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
523
+ stride_kn, stride_kd, stride_vn, stride_vd,
524
+ kv_indices, sparse_kv_num_blocks,
525
+ MATMUL_PRECISION, RCP_LN2,
526
+ IS_FULL_BLOCKS,
527
+ )
528
+
529
+ # Increment pointers.
530
+ offset = get_offset_for_next_block(
531
+ start_n, kv_indices, sparse_kv_num_blocks,
532
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
533
+ )
534
+
535
+ kT_ptrs += offset * stride_kn
536
+ vT_ptrs += offset * stride_vn
537
+
538
+ offs_n2 += offset
539
+
540
+ return dq
541
+
542
+
543
+ @triton.jit
544
+ def bwd_dq_block_mn(
545
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
546
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
547
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
548
+ stride_kn, stride_kd, stride_vn, stride_vd,
549
+ kv_indices, sparse_kv_num_blocks,
550
+ MATMUL_PRECISION, RCP_LN2,
551
+ IS_FULL_BLOCKS,
552
+ ):
553
+ PRESCALE_QK : tl.constexpr = False
554
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
555
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
556
+ WRITE_DQ : tl.constexpr = True
557
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
558
+ OUTPUT_MAX : tl.constexpr = False
559
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
560
+ IS_DIVISIBLE : tl.constexpr = False
561
+ SM_SCALE : tl.constexpr = 0.08838834764831845
562
+ GQA_SHARED_HEADS : tl.constexpr = 4
563
+ HAS_FULL_BLOCKS : tl.constexpr = True
564
+ QK_HEAD_DIM : tl.constexpr = 128
565
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ V_HEAD_DIM : tl.constexpr = 128
567
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
568
+ SAFE_HEAD_DIM : tl.constexpr = True
569
+ BLOCK_M1 : tl.constexpr = 64
570
+ BLOCK_N1 : tl.constexpr = 128
571
+ BLOCK_M2 : tl.constexpr = 128
572
+ BLOCK_N2 : tl.constexpr = 64
573
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
574
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
575
+ INDEX_DTYPE : tl.constexpr = tl.int32
576
+
577
+
578
+ # NB reversed order since K is transposed
579
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
580
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
581
+ if not PRESCALE_QK:
582
+ qk *= SM_SCALE
583
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
584
+ pre_mod_scores = qk
585
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
586
+ # The boundary check is done for the outer loop, but here, since we're iterating across the N dim,
587
+ # it's possible that the M reads go out of bounds for the PIDs spanning the Q_LEN boundary
588
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
589
+
590
+ tmp0 = (qk)
591
+ post_mod_scores = tmp0
592
+
593
+
594
+
595
+
596
+ if not IS_DIVISIBLE:
597
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
598
+
599
+ if not IS_FULL_BLOCKS:
600
+ tmp1 = (m)
601
+ tmp2 = tl.full([1], 0, tl.int32)
602
+ tmp3 = tmp1 < tmp2
603
+ tmp4 = (n)
604
+ tmp5 = tmp4 <= tmp1
605
+ tmp6 = tmp3 & tmp5
606
+ tmp7 = tmp1 >= tmp2
607
+ tmp8 = tmp4 < tmp2
608
+ tmp9 = tmp7 & tmp8
609
+ tmp10 = tmp8 == 0
610
+ tmp11 = tmp7 & tmp10
611
+ tmp12 = tmp1 - tmp2
612
+ tmp13 = tl.full([1], 16, tl.int32)
613
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
614
+ tmp15 = tmp4 - tmp2
615
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
616
+ tmp17 = tmp14 == tmp16
617
+ tmp18 = tmp11 & tmp17
618
+ tmp19 = tmp9 | tmp18
619
+ tmp20 = tmp6 | tmp19
620
+ mask_mod_output = tmp20
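+ # Reading of the inlined mask_mod above: tmp6 covers m < 0 with n <= m, tmp9 covers m >= 0 with n < 0,
+ # and tmp18 requires m >= 0, n >= 0 and floor(m / 16) == floor(n / 16) (the tl.where ladder is a
+ # sign-correct floor division). Since get_bounded_indices keeps m and n non-negative here, the mask
+ # appears to reduce to a block-diagonal pattern over groups of 16 positions (m // 16 == n // 16).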
621
+
622
+
623
+ # apply mask for a partially masked block
624
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
625
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
626
+ if not PRESCALE_QK:
627
+ post_mod_scores *= RCP_LN2
628
+ p = tl.math.exp2(post_mod_scores - lse)
629
+ # Compute dP and dS.
630
+ # NB reversed order since V is transposed
631
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
632
+
633
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
634
+ ds = p * (dp - Di[:, None])
635
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
636
+ tmp21 = (ds)
637
+ grad_scores = tmp21
638
+
639
+
640
+ if not IS_DIVISIBLE:
641
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
642
+
643
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
644
+ if WRITE_DQ:
645
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
646
+
647
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
648
+ ds = grad_scores
649
+
650
+ if not IS_FULL_BLOCKS:
651
+ # (grads) apply mask for partially unmasked block
652
+ ds = tl.where(mask_mod_output, ds, 0.0)
653
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
654
+ ds = ds.to(MATMUL_PRECISION)
655
+ # Compute dQ.
656
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
657
+
658
+ return dq
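+ # Summary of the dQ block update above (standard flex-attention backward algebra): p recovers the
+ # forward probabilities from the saved base-2 logsumexp, dp = do @ V^T, ds = p * (dp - Delta), and
+ # dq accumulates ds @ K (tl.trans(kT) is K itself, since kT was loaded transposed).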
659
+
660
+
661
+ @triton.jit
662
+ def bwd_dkdv_inner(
663
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
664
+ Q, DO, DELTA, LSE, # pointers
665
+ dk, dv, k, v,
666
+ off_z, off_hq, offs_n1, offs_m1,
667
+ stride_qm, stride_qd, stride_dom, stride_dod,
668
+ q_indices, sparse_q_num_blocks,
669
+ MATMUL_PRECISION,
670
+ IS_FULL_BLOCKS,
671
+ ):
672
+ PRESCALE_QK : tl.constexpr = False
673
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
674
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
675
+ WRITE_DQ : tl.constexpr = True
676
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
677
+ OUTPUT_MAX : tl.constexpr = False
678
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
679
+ IS_DIVISIBLE : tl.constexpr = False
680
+ SM_SCALE : tl.constexpr = 0.08838834764831845
681
+ GQA_SHARED_HEADS : tl.constexpr = 4
682
+ HAS_FULL_BLOCKS : tl.constexpr = True
683
+ QK_HEAD_DIM : tl.constexpr = 128
684
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
685
+ V_HEAD_DIM : tl.constexpr = 128
686
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
687
+ SAFE_HEAD_DIM : tl.constexpr = True
688
+ BLOCK_M1 : tl.constexpr = 64
689
+ BLOCK_N1 : tl.constexpr = 128
690
+ BLOCK_M2 : tl.constexpr = 128
691
+ BLOCK_N2 : tl.constexpr = 64
692
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
693
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
694
+ INDEX_DTYPE : tl.constexpr = tl.int32
695
+
696
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
697
+ RCP_LN2: tl.constexpr = 1.44269504
698
+ Q_LEN = ks0
699
+ KV_LEN = ks1
700
+
701
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
702
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
703
+
704
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
705
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
706
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
707
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
708
+
709
+ # The minimum is needed to handle the case where we run with a super large
710
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
711
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
712
+
713
+ for start_m in range(0, hi):
714
+ dk, dv = bwd_dkdv_block_mn(
715
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
716
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
717
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
718
+ stride_qm, stride_qd, stride_dom, stride_dod,
719
+ q_indices, sparse_q_num_blocks,
720
+ MATMUL_PRECISION, RCP_LN2,
721
+ IS_FULL_BLOCKS,
722
+ )
723
+ # Increment pointers.
724
+ offset = get_offset_for_next_block(
725
+ start_m, q_indices, sparse_q_num_blocks,
726
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
727
+ )
728
+
729
+ qT_ptrs += offset * stride_qm
730
+ do_ptrs += offset * stride_dom
731
+ offs_m1 += offset
732
+
733
+ return dk, dv
734
+
735
+
736
+ @triton.jit
737
+ def bwd_dkdv_block_mn(
738
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
739
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
740
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
741
+ stride_qm, stride_qd, stride_dom, stride_dod,
742
+ q_indices, sparse_q_num_blocks,
743
+ MATMUL_PRECISION, RCP_LN2,
744
+ IS_FULL_BLOCKS,
745
+ ):
746
+ PRESCALE_QK : tl.constexpr = False
747
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
748
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
749
+ WRITE_DQ : tl.constexpr = True
750
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
751
+ OUTPUT_MAX : tl.constexpr = False
752
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
753
+ IS_DIVISIBLE : tl.constexpr = False
754
+ SM_SCALE : tl.constexpr = 0.08838834764831845
755
+ GQA_SHARED_HEADS : tl.constexpr = 4
756
+ HAS_FULL_BLOCKS : tl.constexpr = True
757
+ QK_HEAD_DIM : tl.constexpr = 128
758
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
759
+ V_HEAD_DIM : tl.constexpr = 128
760
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
761
+ SAFE_HEAD_DIM : tl.constexpr = True
762
+ BLOCK_M1 : tl.constexpr = 64
763
+ BLOCK_N1 : tl.constexpr = 128
764
+ BLOCK_M2 : tl.constexpr = 128
765
+ BLOCK_N2 : tl.constexpr = 64
766
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
767
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
768
+ INDEX_DTYPE : tl.constexpr = tl.int32
769
+
770
+
771
+ # NB reversed order since Q is transposed
772
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
773
+ # Load LSE before computing qk to reduce pipeline stall.
774
+ if IS_DIVISIBLE:
775
+ lse = tl.load(LSE + offs_m1)
776
+ else:
777
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
778
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
779
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
780
+ if not PRESCALE_QK:
781
+ qkT *= SM_SCALE
782
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
783
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
784
+ # The boundary check is done for the outer loop, but since we're iterating across the M dim here,
785
+ # the n reads can go out of bounds for the PIDs that span the KV_LEN boundary.
786
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
787
+
788
+ pre_mod_scores = qkT
789
+ tmp22 = (qkT)
790
+ post_mod_scores = tmp22
791
+
792
+
793
+
794
+ if not IS_DIVISIBLE:
795
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
796
+
797
+ if not IS_FULL_BLOCKS:
798
+ tmp23 = (m)
799
+ tmp24 = tl.full([1], 0, tl.int32)
800
+ tmp25 = tmp23 < tmp24
801
+ tmp26 = (n)
802
+ tmp27 = tmp26 <= tmp23
803
+ tmp28 = tmp25 & tmp27
804
+ tmp29 = tmp23 >= tmp24
805
+ tmp30 = tmp26 < tmp24
806
+ tmp31 = tmp29 & tmp30
807
+ tmp32 = tmp30 == 0
808
+ tmp33 = tmp29 & tmp32
809
+ tmp34 = tmp23 - tmp24
810
+ tmp35 = tl.full([1], 16, tl.int32)
811
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
812
+ tmp37 = tmp26 - tmp24
813
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
814
+ tmp39 = tmp36 == tmp38
815
+ tmp40 = tmp33 & tmp39
816
+ tmp41 = tmp31 | tmp40
817
+ tmp42 = tmp28 | tmp41
818
+ mask_mod_output = tmp42
819
+
820
+ # (grads) apply mask for a partially masked block
821
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
822
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
823
+ if not PRESCALE_QK:
824
+ post_mod_scores *= RCP_LN2
825
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
826
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
827
+ # Compute dV.
828
+ ppT = pT
829
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
830
+ if IS_DIVISIBLE:
831
+ Di = tl.load(DELTA + offs_m1)
832
+ else:
833
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
834
+ # Compute dP and dS.
835
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
836
+ dsT = pT * (dpT - Di[None, :])
837
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
838
+ tmp43 = (dsT)
839
+ grad_scores = tmp43
840
+
841
+
842
+
843
+ if not IS_DIVISIBLE:
844
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
845
+
846
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
847
+ if not WRITE_DQ:
848
+ idx_b = off_z
849
+ idx_h = off_hq
850
+ idx_m = m
851
+ idx_n = n
852
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
853
+
854
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
855
+ dsT = grad_scores
856
+ if not IS_FULL_BLOCKS:
857
+ # (grads) apply mask for partially unmasked block
858
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
861
+
862
+ return dk, dv
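+ # Summary of the dK/dV block update above: pT recomputes the (transposed) probabilities from the
+ # saved logsumexp, dv accumulates P^T @ dO, dpT = V @ dO^T, dsT = pT * (dpT - Delta), and dk
+ # accumulates dS^T @ Q (tl.trans(qT) is Q itself, since qT was loaded transposed).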
863
+
864
+ # Utility triton funcs
865
+ @triton.jit
866
+ def get_offset_for_next_block(
867
+ loop_iter, col_indices, total_blocks,
868
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
869
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
870
+ ):
871
+ if BLOCKS_ARE_CONTIGUOUS:
872
+ return BLOCK
873
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
874
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
875
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
876
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
877
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
878
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
879
+ return offset
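+ # Worked example for the constants used in this kernel (SPARSE_BLOCK = 128, BLOCK = 64, so
+ # SPARSE_BLOCK_MULTIPLE = 2): while still inside the current sparse block the pointers advance by
+ # BLOCK = 64; on the iteration that finishes a sparse block ((loop_iter + 1) % 2 == 0) they advance
+ # by (next_block - cur_block) * 128 - 64, e.g. cur_block = 0 and next_block = 3 gives 320, which
+ # lands exactly at element 3 * 128 of the selected sequence.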
880
+
881
+ @triton.jit
882
+ def get_bounded_indices(indices, max_len=None):
883
+ return indices % max_len if max_len is not None else indices
884
+
885
+ @triton.jit
886
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
887
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
888
+ return tl.load(block_ptr)
889
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
890
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
891
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
892
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
893
+ else:
894
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
895
+
896
+ @triton.jit
897
+ def load_checked_2d(
898
+ ptr,
899
+ offs_m,
900
+ offs_n,
901
+ stride_m,
902
+ stride_n,
903
+ IS_DIVISIBLE_M: tl.constexpr,
904
+ IS_DIVISIBLE_N: tl.constexpr,
905
+ M_LEN: tl.constexpr,
906
+ N_LEN: tl.constexpr,
907
+ ):
908
+ # Calculate final pointer if strides are provided
909
+ if stride_m is not None and stride_n is not None:
910
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
911
+
912
+ # Handle all masking cases
913
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
914
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
915
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
916
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
917
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
918
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
919
+ else: # Both divisible
920
+ return tl.load(ptr)
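+ # Note: when stride_m / stride_n are None (as at the kT_ptrs / vT_ptrs / qT_ptrs call sites above),
+ # the caller has already materialized the 2-D pointer grid and this helper only applies the bounds
+ # masking; otherwise it also builds ptr + offs_m * stride_m + offs_n * stride_n first.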
921
+ ''', device_str='cuda')
922
+
923
+
924
+ async_compile.wait(globals())
925
+ del async_compile
926
+
927
+ class Runner:
928
+ def __init__(self, partitions):
929
+ self.partitions = partitions
930
+
931
+ def recursively_apply_fns(self, fns):
932
+ new_callables = []
933
+ for fn, c in zip(fns, self.partitions):
934
+ new_callables.append(fn(c))
935
+ self.partitions = new_callables
936
+
937
+ def call(self, args):
938
+ primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args
939
+ args.clear()
940
+ s37 = primals_8
941
+ s0 = primals_9
942
+ assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
943
+ assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
944
+ assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
945
+ assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
946
+ assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
947
+ assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
948
+ assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
949
+ assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
950
+ assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
951
+ assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
952
+ assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
953
+ assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
954
+ assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
955
+ assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
956
+ assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
957
+ with torch.cuda._DeviceGuard(7):
958
+ torch.cuda.set_device(7)
959
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
960
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
961
+ triton_per_fused_mul_0_xnumel = 32*s37
962
+ stream7 = get_raw_stream(7)
963
+ triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream7)
964
+ del getitem
965
+ del tangents_2
966
+ buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
967
+ buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
968
+ buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
969
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
970
+ stream7 = get_raw_stream(7)
971
+ triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream7)
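+ # The three trailing launcher args appear to be the fixed grid: dim 0 is GQA_SHARED_HEADS (4) *
+ # cdiv(Q_LEN = s37, 128) + cdiv(KV_LEN = s0, 128), i.e. the dQ programs plus the dK/dV programs,
+ # dim 1 is the batch (1), and dim 2 is the number of KV heads (8), matching the template's pid layout.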
972
+ del buf1
973
+ del getitem_1
974
+ del primals_10
975
+ del primals_11
976
+ del primals_12
977
+ del primals_13
978
+ del primals_14
979
+ del primals_15
980
+ del primals_16
981
+ del primals_2
982
+ del primals_4
983
+ del primals_6
984
+ del primals_7
985
+ del tangents_1
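+ # Judging by the argument order of the kernel call above, buf3 holds dQ (query-shaped,
+ # 1 x 32 x s37 x 128), buf4 holds dV and buf5 holds dK (both KV-shaped, 1 x 8 x s0 x 128);
+ # the None entries below correspond to inputs that require no gradient.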
986
+ return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, )
987
+
988
+ runner = Runner(partitions=[])
989
+ call = runner.call
990
+ recursively_apply_fns = runner.recursively_apply_fns
991
+
992
+
993
+ def benchmark_compiled_module(times=10, repeat=10):
994
+ from torch._dynamo.testing import rand_strided
995
+ from torch._inductor.utils import print_performance
996
+ primals_8 = 80
997
+ primals_9 = 80
998
+ primals_2 = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
999
+ primals_4 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
1000
+ primals_6 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
1001
+ primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1002
+ primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1003
+ primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1004
+ primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1005
+ primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1006
+ primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1007
+ primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1008
+ primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1009
+ getitem = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
1010
+ getitem_1 = rand_strided((1, 32, 80), (2560, 80, 1), device='cuda:7', dtype=torch.float32)
1011
+ tangents_1 = rand_strided((1, 32, 80, 128), (327680, 10240, 128, 1), device='cuda:7', dtype=torch.bfloat16)
1012
+ tangents_2 = rand_strided((1, 32, 80), (2560, 80, 1), device='cuda:7', dtype=torch.float32)
1013
+ fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2])
1014
+ return print_performance(fn, times=times, repeat=repeat)
1015
+
1016
+
1017
+ if __name__ == "__main__":
1018
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1019
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/44/3728d77fd47f8b1056ec8670d5b1bd262db03ae9994292fee6203d32e3d9cd03.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"}
progress/SpecForge/cache/compiled_kernels/44/c44m5klhlzg7nfvzfelnbb3hjh2jwzh2e5yyk3vtcvhyw6rbnjo6.py ADDED
@@ -0,0 +1,54 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 524288, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x5 = xindex
33
+ x1 = xindex // 128
34
+ x0 = (xindex % 128)
35
+ x3 = ((xindex // 128) % ks0)
36
+ x4 = xindex // ks1
37
+ tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
38
+ tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
39
+ tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
40
+ tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
41
+ tmp2 = float("-inf")
42
+ tmp3 = tmp1 == tmp2
43
+ tmp5 = tmp4 - tmp1
44
+ tmp6 = 0.0
45
+ tmp7 = tl.where(tmp3, tmp6, tmp5)
46
+ tmp8 = libdevice.exp2(tmp7)
47
+ tmp9 = tmp0 * tmp8
48
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
49
+ tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
50
+ tmp14 = 1.0
51
+ tmp15 = tl.where(tmp3, tmp14, tmp13)
52
+ tmp16 = (tmp12 / tmp15)
53
+ tmp17 = tmp16.to(tl.float32)
54
+ tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
progress/SpecForge/cache/compiled_kernels/4d/c4d7fh2egdfps7aogbncwlp3ihfwtff243bbobq7vrxj2m2grl64.py ADDED
@@ -0,0 +1,51 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 32768, 'r0_': 128},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 128
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
29
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
30
+ x3 = xindex
31
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
32
+ r0_index = r0_offset + r0_base
33
+ r0_mask = r0_index < r0_numel
34
+ roffset = r0_offset
35
+ rindex = r0_index
36
+ r0_2 = r0_index
37
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
38
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
39
+ tmp2 = tmp0 * tmp1
40
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
41
+ tmp5 = _tmp4 + tmp3
42
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
43
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
44
+ tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
45
+ tmp6 = tmp4.to(tl.float32)
46
+ tmp8 = 0.6931471805599453
47
+ tmp9 = tmp7 * tmp8
48
+ tmp10 = 1.4426950408889634
49
+ tmp11 = tmp9 * tmp10
50
+ tmp12 = tmp6 - tmp11
51
+ tl.store(out_ptr1 + (x3), tmp12, xmask)
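+ # Note on the tail above: tmp8 * tmp10 = ln(2) * log2(e) = 1, so tmp12 is simply the head-dim (128)
+ # reduction of in_ptr0 * in_ptr1 minus the per-row value in in_ptr2, consistent with the
+ # DELTA-style precompute (sum(OUT * DO, axis=-1)) that the flex-attention backward template consumes.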
progress/SpecForge/cache/compiled_kernels/4d/fd68b3c1a3fd19883dc58697393b6044e6217afda9ea11f84bd620545197dd6b.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 40, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"}
progress/SpecForge/cache/compiled_kernels/4h/c4h32peoig2erjdxibxrq3sbpm533ci3z57ntqjhdemzxp2rhysl.py ADDED
@@ -0,0 +1,799 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 1
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 1
148
+ stride_kv_idx_h = 1
149
+ stride_kv_idx_m = 1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 1
245
+ stride_q_idx_h = 1
246
+ stride_q_idx_n = 1
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
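+ # The store index (index_k + 128 * off_hkv + 1024 * index_n) addresses out_ptr0 as the
+ # (1, 8, KV_LEN, 128) dK buffer with per-dim strides (128, 1024, 1) over (head, n, k);
+ # off_zq drops out because ZQ == 1 here, and the xindex computed just above is generated
+ # but unused by this store.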
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order to since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = ks0
578
+ KV_LEN = ks1
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but since we're iterating across the M dim here, it's possible
664
+ # that n reads out of bounds for the PIDs spanning the KV_LEN boundary
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
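For orientation, the file above is the Inductor-generated flex-attention backward template. The following is a rough dense NumPy sketch of the per-block dK/dV math that bwd_dkdv_block_mn accumulates; shapes and the convention that DELTA holds rowsum(do * out) are assumptions, and the template's masking, GQA indexing, and block-sparse iteration are omitted.

import numpy as np

RCP_LN2 = 1.44269504            # 1/ln(2), same constant as in the kernel
SM_SCALE = 0.08838834764831845  # 1/sqrt(128), same constant as in the kernel

def dkdv_block_reference(q, k, v, do, lse, delta):
    # q, do: [M, D]; k, v: [N, D]; lse, delta: [M]
    # delta is assumed to hold rowsum(do * out), as in standard flash-attention backward.
    qkT = (k @ q.T) * SM_SCALE                   # [N, M] transposed scores
    pT = np.exp2(qkT * RCP_LN2 - lse[None, :])   # probabilities via the base-2 exp trick
    dv = pT @ do                                 # dV contribution of this block
    dpT = v @ do.T                               # dP^T
    dsT = pT * (dpT - delta[None, :])            # dS^T, mirroring dsT = pT * (dpT - Di)
    dk = dsT @ q                                 # dK contribution of this block
    return dk, dv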
progress/SpecForge/cache/compiled_kernels/4k/c4korm4huj2wookuw6gikboxrsp3m5yt45c7fxucyujswm5fgb3u.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 1
130
+ stride_kv_idx_h = 1
131
+ stride_kv_idx_m = 1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmarks show that even if we apply mod & mask to each block for a non-divisible seqlen,
504
+ # it's on par or slightly faster than only applying to the last block in fwd.
505
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's a lot faster.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
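As a companion to the forward template above, here is a minimal NumPy sketch of the online-softmax update that forward_block_mn applies per KV block. It assumes dense inputs and skips the mask_mod, ROWS_GUARANTEED_SAFE, and boundary handling of the generated kernel.

import numpy as np

RCP_LN2 = 1.44269504
SM_SCALE = 0.08838834764831845

def online_softmax_step(q, k_blk, v_blk, acc, l_i, m_i):
    # q: [M, D]; k_blk, v_blk: [N, D]; acc: [M, Dv]; l_i, m_i: [M]
    s = (q @ k_blk.T) * SM_SCALE * RCP_LN2       # scores moved into base-2 exponent space
    m_ij = np.maximum(m_i, s.max(axis=1))        # new running row max
    alpha = np.exp2(m_i - m_ij)                  # rescale factor for the old accumulator
    p = np.exp2(s - m_ij[:, None])
    l_i = l_i * alpha + p.sum(axis=1)            # running softmax denominator
    acc = acc * alpha[:, None] + p @ v_blk       # rescale old acc, add this block's contribution
    return acc, l_i, m_ij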
progress/SpecForge/cache/compiled_kernels/4x/31f9d1ee4882fe2005f02592ea2d9f20a1835b42c5baefd7795e8640f97fdc16.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/4x/c4xjhgyzut6anhrjeinspoinohfxvyl6skr4gd3vfrscrvsevmya.py ADDED
@@ -0,0 +1,28 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 32768},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x2 = xindex
23
+ x0 = (xindex % ks0)
24
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
25
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
26
+ tmp1 = 0.6931471805599453
27
+ tmp2 = tmp0 * tmp1
28
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
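The constant 0.6931471805599453 in this pointwise kernel is ln(2), so triton_poi_fused_mul_1 is just an elementwise multiply by ln(2), presumably converting a base-2 logsumexp back to natural log. A plain PyTorch equivalent, under that assumption, is:

import math
import torch

def mul_by_ln2_reference(x: torch.Tensor) -> torch.Tensor:
    # Elementwise multiply by ln(2) == 0.6931471805599453, matching the Triton kernel.
    return x * math.log(2.0)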
progress/SpecForge/cache/compiled_kernels/5b/c5blvz5sxoj2veuexokuub2zm2pg4l2nqbbny4rr2jhsiiyw6njy.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 1
130
+ stride_kv_idx_h = 1
131
+ stride_kv_idx_m = 1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
504
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
505
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's much faster.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
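+ # Within a sparse block this advances by BLOCK_N; when the loop crosses a SPARSE_KV_BLOCK_SIZE
+ # boundary, get_offset_for_next_block instead jumps to the start of the next block listed in
+ # kv_indices, so offs_n / kv_offset always point at the next tile to process.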
532
+
533
+
534
+ return acc, l_i, m_i
progress/SpecForge/cache/compiled_kernels/5g/c5g7nnbi3zupsx7kdee2ed6g2fgrtd2jxyggsjpckfg5p7rps4qm.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
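+ # GQA mapping: each group of GQA_SHARED_HEADS (= 4) consecutive query heads shares one KV head,
+ # so the 32 query heads handled by this kernel index into 8 KV heads via off_hkv.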
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 1
130
+ stride_kv_idx_h = 1
131
+ stride_kv_idx_m = 1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
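+ # kv_num_blocks counts SPARSE_KV_BLOCK_SIZE-wide blocks; multiplying by SPARSE_KV_MULTIPLE
+ # converts that into a count of BLOCK_N-wide tiles, clamped to the number of tiles that
+ # actually exist in KV_LEN.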
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0, which will result in lse/out = 0.0 after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
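+ # m_i and l_i were accumulated in the base-2 domain (scores scaled by RCP_LN2), so the stored
+ # logsumexp is a base-2 value; downstream code multiplies by ln(2) where a natural-log value is needed.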
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
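+ # Branchless select: needs_jump is 0/1, so this evaluates to jump_to_block at sparse-block
+ # boundaries and to BLOCK (a plain tile-sized step) everywhere else.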
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB: reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
357
+ # we need to mask out the elements that are beyond Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
504
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
505
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's much faster.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
progress/SpecForge/cache/compiled_kernels/5j/3cc65a0fdb544c73efb7240355b77da3f1ab394b46f272fa923c368e6cc63c34.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 96, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"}
progress/SpecForge/cache/compiled_kernels/5j/c5j7yk5hlaaxs42qwjlmoczwtoukaw2dio2o6p7qfekdy5upikyv.py ADDED
@@ -0,0 +1,54 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 524288, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x5 = xindex
33
+ x1 = xindex // 128
34
+ x0 = (xindex % 128)
35
+ x3 = ((xindex // 128) % ks0)
36
+ x4 = xindex // ks1
37
+ tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
38
+ tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
39
+ tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
40
+ tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
41
+ tmp2 = float("-inf")
42
+ tmp3 = tmp1 == tmp2
43
+ tmp5 = tmp4 - tmp1
44
+ tmp6 = 0.0
45
+ tmp7 = tl.where(tmp3, tmp6, tmp5)
46
+ tmp8 = libdevice.exp2(tmp7)
47
+ tmp9 = tmp0 * tmp8
48
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
49
+ tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
50
+ tmp14 = 1.0
51
+ tmp15 = tl.where(tmp3, tmp14, tmp13)
52
+ tmp16 = (tmp12 / tmp15)
53
+ tmp17 = tmp16.to(tl.float32)
54
+ tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
progress/SpecForge/cache/compiled_kernels/5w/c5wutjfcact264ykgcamj2asvz4eqe3ygz47upjgib2qw5rnnihu.py ADDED
@@ -0,0 +1,1019 @@
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ns/cnsdsjpw2wa2anwrjd5vp4pkfq4mtbebhfmyzeol2hd4mzwu6qnk.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
44
+ # %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=tangents_2]
45
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
46
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
47
+ # return %buf0,%buf1
48
+ triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', '''
49
+ import triton
50
+ import triton.language as tl
51
+
52
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
53
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
54
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
55
+ triton_helpers.set_driver_to_gpu()
56
+
57
+ @triton_heuristics.persistent_reduction(
58
+ size_hints={'x': 4096, 'r0_': 128},
59
+ reduction_hint=ReductionHint.INNER,
60
+ filename=__file__,
61
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
62
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
63
+ )
64
+ @triton.jit
65
+ def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
66
+ r0_numel = 128
67
+ R0_BLOCK: tl.constexpr = 128
68
+ rnumel = r0_numel
69
+ RBLOCK: tl.constexpr = R0_BLOCK
70
+ xoffset = tl.program_id(0) * XBLOCK
71
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
72
+ xmask = xindex < xnumel
73
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
74
+ r0_offset = 0
75
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
76
+ roffset = r0_offset
77
+ rindex = r0_index
78
+ r0_2 = r0_index
79
+ x0 = (xindex % ks0)
80
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
81
+ x3 = xindex
82
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
83
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
84
+ tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
85
+ tmp2 = tmp0 * tmp1
86
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
87
+ tmp5 = tl.where(xmask, tmp3, 0)
88
+ tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
89
+ tmp7 = tmp6.to(tl.float32)
90
+ tmp9 = 0.6931471805599453
91
+ tmp10 = tmp8 * tmp9
92
+ tmp11 = 1.4426950408889634
93
+ tmp12 = tmp10 * tmp11
94
+ tmp13 = tmp7 - tmp12
95
+ tl.store(out_ptr1 + (x3), tmp13, xmask)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2]
104
+ # %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4]
105
+ # %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6]
106
+ # %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5]
111
+ # %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_10]
112
+ # %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_7]
113
+ # %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_13]
114
+ # %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_14]
115
+ # %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_11]
116
+ # %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_12]
117
+ # %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_15]
118
+ # %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_16]
119
+ # %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
120
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
121
+ # return %getitem_4
122
+ triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
123
+ import triton
124
+ import triton.language as tl
125
+
126
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
127
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
128
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
129
+
130
+ @triton_heuristics.template(
131
+
132
+ num_stages=3,
133
+ num_warps=8,
134
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
135
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
136
+
137
+ )
138
+ @triton.jit
139
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
140
+ PRESCALE_QK : tl.constexpr = False
141
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
142
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
143
+ WRITE_DQ : tl.constexpr = True
144
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
145
+ OUTPUT_MAX : tl.constexpr = False
146
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
147
+ IS_DIVISIBLE : tl.constexpr = False
148
+ SM_SCALE : tl.constexpr = 0.08838834764831845
149
+ GQA_SHARED_HEADS : tl.constexpr = 4
150
+ HAS_FULL_BLOCKS : tl.constexpr = True
151
+ QK_HEAD_DIM : tl.constexpr = 128
152
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
153
+ V_HEAD_DIM : tl.constexpr = 128
154
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
155
+ SAFE_HEAD_DIM : tl.constexpr = True
156
+ BLOCK_M1 : tl.constexpr = 64
157
+ BLOCK_N1 : tl.constexpr = 128
158
+ BLOCK_M2 : tl.constexpr = 128
159
+ BLOCK_N2 : tl.constexpr = 64
160
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
161
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
162
+ INDEX_DTYPE : tl.constexpr = tl.int32
163
+ Q = arg_Q
164
+ K = arg_K
165
+ V = arg_V
166
+ LSE = arg_LSE
167
+ DELTA = arg_DELTA
168
+ DO = arg_DO
169
+ DQ = arg_DQ
170
+ DV = arg_DV
171
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
172
+ KV_IDX = arg_KV_IDX
173
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
174
+ Q_IDX = arg_Q_IDX
175
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
176
+ FULL_KV_IDX = arg_FULL_KV_IDX
177
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
178
+ FULL_Q_IDX = arg_FULL_Q_IDX
179
+
180
+ # Sub notation for this kernel:
181
+ #
182
+ # Q: Query, K: Key, V: Value
183
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
184
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
185
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
186
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
187
+ # inductor codegen
188
+ # M: Number of queries, N: Number of keys/values
189
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
190
+ # V_HEAD_DIM: The dimension of the value embeddings
191
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
192
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
193
+ # (Modifiable) Performance tuning options
194
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
195
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
196
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
197
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
198
+ #
199
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
200
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
201
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
202
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
203
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
204
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
205
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
207
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
208
+
209
+ # The below are kernel options that can be applied for certain score_mods,
210
+ # or involve a numerics vs. perf tradeoff
211
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
212
+ # about 20% more numerical error, but slightly faster.
213
+
214
+ # Define strides of inputs
215
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
216
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
217
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
218
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
219
+
220
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
221
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
222
+
223
+ ZQ = 1
224
+ HQ = 32
225
+ HKV = 8
226
+ Q_LEN = ks0
227
+ ZKV = 1
228
+ KV_LEN = ks1
229
+
230
+ MATMUL_PRECISION = Q.dtype.element_ty
231
+
232
+ pid = tl.program_id(0).to(INDEX_DTYPE)
233
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
234
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
235
+
236
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
237
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
238
+ off_zkv = off_zq % ZKV # kv batch idx
239
+
240
+ SPARSE_Z = 1
241
+ SPARSE_HQ = 1
242
+
243
+ sparse_idx_z = off_zq % SPARSE_Z
244
+
245
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
246
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
247
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
248
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
249
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
250
+
251
+ # offset K, V, DV pointers for batch/kv-head
252
+ K += k_adj
253
+ V += v_adj
254
+ DV += dv_adj
255
+
256
+ RCP_LN2 = 1.44269504
257
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
258
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
259
+
260
+ if pid >= NUM_KV_BLOCKS:
261
+ off_pid = pid - NUM_KV_BLOCKS
262
+ # THIS BLOCK DOES DQ
263
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
264
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
265
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
266
+ start_m2_block = off_pid % NUM_Q_BLOCKS
267
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
268
+ stride_kv_num_blks_h = 1
269
+ stride_kv_idx_h = 1
270
+ stride_kv_idx_m = 1
271
+
272
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
273
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
274
+
275
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
276
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
277
+
278
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
279
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
280
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
281
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
282
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
283
+
284
+ Q2 = Q + q_adj2
285
+ DO2 = DO + do_adj2
286
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
287
+ # if Q is broadcasted)
288
+ DQ2 = DQ + dq_adj2
289
+ LSE2 = LSE + off_chz2
290
+ DELTA2 = DELTA + off_chz2
291
+
292
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
293
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
294
+
295
+ start_m2 = start_m2_block * BLOCK_M2
296
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
297
+
298
+ # load Q and do: they stay in SRAM throughout the inner loop.
299
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
300
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
301
+
302
+ if PRESCALE_QK:
303
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
304
+
305
+ if IS_DIVISIBLE:
306
+ Di = tl.load(DELTA2 + offs_m2)
307
+ lse = tl.load(LSE2 + offs_m2)
308
+ else:
309
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
310
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
312
+ lse = lse[:, None]
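+ # Rows that were fully masked in the forward pass stored lse == -inf; replacing it with 0.0
+ # above keeps the exponentiation in the dq inner loop finite for those rows.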
313
+
314
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
315
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
316
+ kv_indices = KV_IDX + sparse_kv_idx_offset
317
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
318
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
319
+
320
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
321
+ dq = bwd_dq_inner(
322
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
323
+ K, V,
324
+ dq, q, do, Di, lse,
325
+ off_zq, off_hq2, offs_m2, offs_n2,
326
+ stride_kn, stride_kd, stride_vn, stride_vd,
327
+ kv_indices, sparse_kv_num_blocks,
328
+ MATMUL_PRECISION,
329
+ IS_FULL_BLOCKS=False,
330
+ )
331
+
332
+ if HAS_FULL_BLOCKS:
333
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
334
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
335
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
336
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
337
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
338
+
339
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
340
+ dq = bwd_dq_inner(
341
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
342
+ K, V,
343
+ dq, q, do, Di, lse,
344
+ off_zq, off_hq2, offs_m2, offs_n2,
345
+ stride_kn, stride_kd, stride_vn, stride_vd,
346
+ kv_indices, sparse_kv_num_blocks,
347
+ MATMUL_PRECISION,
348
+ IS_FULL_BLOCKS=True,
349
+ )
350
+
351
+ # Write back dQ.
352
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
353
+ dq *= SM_SCALE
354
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
355
+ tl.store(dq_ptrs, dq)
356
+ else:
357
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
358
+ else:
359
+ # THIS BLOCK DOES DK & DV
360
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
361
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
362
+
363
+ pid_mask = pid // SPARSE_KV_MULTIPLE
364
+
365
+ stride_q_num_blks_h = 1
366
+ stride_q_idx_h = 1
367
+ stride_q_idx_n = 1
368
+
369
+
370
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
371
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+
373
+ start_n1 = pid * BLOCK_N1
374
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
375
+
376
+ # load K and V: they stay in SRAM throughout the inner loop.
377
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
378
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
379
+
380
+ if PRESCALE_QK:
381
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
382
+
383
+ for off_g in range(0, GQA_SHARED_HEADS):
384
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
385
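# Illustrative note on the GQA loop above (not part of the generated kernel): with
# GQA_SHARED_HEADS = 4 as declared in this kernel, kv head off_hkv = 2 accumulates
# dK/dV contributions from query heads 8, 9, 10 and 11 over the four loop iterations.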
+
386
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
387
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
388
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
389
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
390
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
391
+
392
+ Q1 = Q + q_adj1
393
+ DO1 = DO + do_adj1
394
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
395
+ # if Q is broadcasted)
396
+ LSE1 = LSE + off_chz1
397
+ DELTA1 = DELTA + off_chz1
398
+
399
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
400
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
401
+
402
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
403
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
404
+
405
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
406
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
407
+ q_indices = Q_IDX + sparse_q_idx_offset
408
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
409
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
410
+
411
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
412
+ dk, dv = bwd_dkdv_inner(
413
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
414
+ Q1, DO1, DELTA1, LSE1,
415
+ dk, dv, k, v,
416
+ off_zq, off_hq1, offs_n1, offs_m1,
417
+ stride_qm, stride_qd, stride_dom, stride_dod,
418
+ q_indices, sparse_q_num_blocks,
419
+ MATMUL_PRECISION,
420
+ IS_FULL_BLOCKS=False,
421
+ )
422
+
423
+
424
+ if HAS_FULL_BLOCKS:
425
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
426
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
427
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
428
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
429
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
430
+
431
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
432
+ dk, dv = bwd_dkdv_inner(
433
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
434
+ Q1, DO1, DELTA1, LSE1,
435
+ dk, dv, k, v,
436
+ off_zq, off_hq1, offs_n1, offs_m1,
437
+ stride_qm, stride_qd, stride_dom, stride_dod,
438
+ q_indices, sparse_q_num_blocks,
439
+ MATMUL_PRECISION,
440
+ IS_FULL_BLOCKS=True,
441
+ )
442
+
443
+ # Write back dV and dK.
444
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
445
+
446
+ index_n = offs_n1[:, None]
447
+ index_k = offs_k[None, :]
448
+ index_v = offs_v[None, :]
449
+
450
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
451
+ tl.store(dv_ptrs, dv)
452
+ else:
453
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
454
+
455
+ dk *= SM_SCALE
456
+
457
+ if SAFE_HEAD_DIM:
458
+ mask = index_n < KV_LEN
459
+ else:
460
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
461
+
462
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
463
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
464
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
465
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
466
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
467
+
468
+ @triton.jit
469
+ def bwd_dq_inner(
470
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
471
+ K, V, # pointers
472
+ dq, q, do, Di, lse,
473
+ off_z, off_hq, offs_m2, offs_n2,
474
+ stride_kn, stride_kd, stride_vn, stride_vd,
475
+ kv_indices, sparse_kv_num_blocks,
476
+ MATMUL_PRECISION,
477
+ IS_FULL_BLOCKS,
478
+ ):
479
+ PRESCALE_QK : tl.constexpr = False
480
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
481
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
482
+ WRITE_DQ : tl.constexpr = True
483
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
484
+ OUTPUT_MAX : tl.constexpr = False
485
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
486
+ IS_DIVISIBLE : tl.constexpr = False
487
+ SM_SCALE : tl.constexpr = 0.08838834764831845
488
+ GQA_SHARED_HEADS : tl.constexpr = 4
489
+ HAS_FULL_BLOCKS : tl.constexpr = True
490
+ QK_HEAD_DIM : tl.constexpr = 128
491
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
492
+ V_HEAD_DIM : tl.constexpr = 128
493
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
494
+ SAFE_HEAD_DIM : tl.constexpr = True
495
+ BLOCK_M1 : tl.constexpr = 64
496
+ BLOCK_N1 : tl.constexpr = 128
497
+ BLOCK_M2 : tl.constexpr = 128
498
+ BLOCK_N2 : tl.constexpr = 64
499
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
500
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
501
+ INDEX_DTYPE : tl.constexpr = tl.int32
502
+
503
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
504
+ RCP_LN2: tl.constexpr = 1.44269504
505
+ Q_LEN = ks0
506
+ KV_LEN = ks1
507
+
508
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
509
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
510
+
511
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
512
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
513
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
514
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
515
+
516
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
517
+
518
+ for start_n in range(0, hi):
519
+ dq = bwd_dq_block_mn(
520
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
521
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
522
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
523
+ stride_kn, stride_kd, stride_vn, stride_vd,
524
+ kv_indices, sparse_kv_num_blocks,
525
+ MATMUL_PRECISION, RCP_LN2,
526
+ IS_FULL_BLOCKS,
527
+ )
528
+
529
+ # Increment pointers.
530
+ offset = get_offset_for_next_block(
531
+ start_n, kv_indices, sparse_kv_num_blocks,
532
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
533
+ )
534
+
535
+ kT_ptrs += offset * stride_kn
536
+ vT_ptrs += offset * stride_vn
537
+
538
+ offs_n2 += offset
539
+
540
+ return dq
541
+
542
+
543
+ @triton.jit
544
+ def bwd_dq_block_mn(
545
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
546
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
547
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
548
+ stride_kn, stride_kd, stride_vn, stride_vd,
549
+ kv_indices, sparse_kv_num_blocks,
550
+ MATMUL_PRECISION, RCP_LN2,
551
+ IS_FULL_BLOCKS,
552
+ ):
553
+ PRESCALE_QK : tl.constexpr = False
554
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
555
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
556
+ WRITE_DQ : tl.constexpr = True
557
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
558
+ OUTPUT_MAX : tl.constexpr = False
559
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
560
+ IS_DIVISIBLE : tl.constexpr = False
561
+ SM_SCALE : tl.constexpr = 0.08838834764831845
562
+ GQA_SHARED_HEADS : tl.constexpr = 4
563
+ HAS_FULL_BLOCKS : tl.constexpr = True
564
+ QK_HEAD_DIM : tl.constexpr = 128
565
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ V_HEAD_DIM : tl.constexpr = 128
567
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
568
+ SAFE_HEAD_DIM : tl.constexpr = True
569
+ BLOCK_M1 : tl.constexpr = 64
570
+ BLOCK_N1 : tl.constexpr = 128
571
+ BLOCK_M2 : tl.constexpr = 128
572
+ BLOCK_N2 : tl.constexpr = 64
573
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
574
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
575
+ INDEX_DTYPE : tl.constexpr = tl.int32
576
+
577
+
578
+ # NB reversed order since K is transposed
579
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
580
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
581
+ if not PRESCALE_QK:
582
+ qk *= SM_SCALE
583
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
584
+ pre_mod_scores = qk
585
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
586
+ # The boundary check is done for the outer loop, but since we're iterating across the N dim here,
587
+ # the M reads can go out of bounds for PIDs spanning the Q_LEN boundary
588
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
589
+
590
+ tmp0 = (qk)
591
+ post_mod_scores = tmp0
592
+
593
+
594
+
595
+
596
+ if not IS_DIVISIBLE:
597
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
598
+
599
+ if not IS_FULL_BLOCKS:
600
+ tmp1 = (m)
601
+ tmp2 = tl.full([1], 0, tl.int32)
602
+ tmp3 = tmp1 < tmp2
603
+ tmp4 = (n)
604
+ tmp5 = tmp4 <= tmp1
605
+ tmp6 = tmp3 & tmp5
606
+ tmp7 = tmp1 >= tmp2
607
+ tmp8 = tmp4 < tmp2
608
+ tmp9 = tmp7 & tmp8
609
+ tmp10 = tmp8 == 0
610
+ tmp11 = tmp7 & tmp10
611
+ tmp12 = tmp1 - tmp2
612
+ tmp13 = tl.full([1], 16, tl.int32)
613
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
614
+ tmp15 = tmp4 - tmp2
615
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
616
+ tmp17 = tmp14 == tmp16
617
+ tmp18 = tmp11 & tmp17
618
+ tmp19 = tmp9 | tmp18
619
+ tmp20 = tmp6 | tmp19
620
+ mask_mod_output = tmp20
621
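# A reading of the generated boolean chain above (an interpretation, not generated code):
# mask_mod_output == (m < 0 and n <= m) or (m >= 0 and n < 0)
#                    or (m >= 0 and n >= 0 and m // 16 == n // 16),
# i.e. for the non-negative bounded indices used here, a block-diagonal mask over
# 16-wide blocks.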
+
622
+
623
+ # apply mask for partial masked block
624
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
625
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
626
+ if not PRESCALE_QK:
627
+ post_mod_scores *= RCP_LN2
628
+ p = tl.math.exp2(post_mod_scores - lse)
629
+ # Compute dP and dS.
630
+ # NB reversed order since V is transposed
631
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
632
+
633
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
634
+ ds = p * (dp - Di[:, None])
635
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
636
+ tmp21 = (ds)
637
+ grad_scores = tmp21
638
+
639
+
640
+ if not IS_DIVISIBLE:
641
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
642
+
643
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
644
+ if WRITE_DQ:
645
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
646
+
647
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
648
+ ds = grad_scores
649
+
650
+ if not IS_FULL_BLOCKS:
651
+ # (grads) apply mask for partially unmasked block
652
+ ds = tl.where(mask_mod_output, ds, 0.0)
653
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
654
+ ds = ds.to(MATMUL_PRECISION)
655
+ # Compute dQ.
656
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
657
+
658
+ return dq
659
+
660
+
661
+ @triton.jit
662
+ def bwd_dkdv_inner(
663
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
664
+ Q, DO, DELTA, LSE, # pointers
665
+ dk, dv, k, v,
666
+ off_z, off_hq, offs_n1, offs_m1,
667
+ stride_qm, stride_qd, stride_dom, stride_dod,
668
+ q_indices, sparse_q_num_blocks,
669
+ MATMUL_PRECISION,
670
+ IS_FULL_BLOCKS,
671
+ ):
672
+ PRESCALE_QK : tl.constexpr = False
673
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
674
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
675
+ WRITE_DQ : tl.constexpr = True
676
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
677
+ OUTPUT_MAX : tl.constexpr = False
678
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
679
+ IS_DIVISIBLE : tl.constexpr = False
680
+ SM_SCALE : tl.constexpr = 0.08838834764831845
681
+ GQA_SHARED_HEADS : tl.constexpr = 4
682
+ HAS_FULL_BLOCKS : tl.constexpr = True
683
+ QK_HEAD_DIM : tl.constexpr = 128
684
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
685
+ V_HEAD_DIM : tl.constexpr = 128
686
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
687
+ SAFE_HEAD_DIM : tl.constexpr = True
688
+ BLOCK_M1 : tl.constexpr = 64
689
+ BLOCK_N1 : tl.constexpr = 128
690
+ BLOCK_M2 : tl.constexpr = 128
691
+ BLOCK_N2 : tl.constexpr = 64
692
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
693
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
694
+ INDEX_DTYPE : tl.constexpr = tl.int32
695
+
696
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
697
+ RCP_LN2: tl.constexpr = 1.44269504
698
+ Q_LEN = ks0
699
+ KV_LEN = ks1
700
+
701
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
702
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
703
+
704
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
705
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
706
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
707
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
708
+
709
+ # The minimum is needed to handle the case where we run with a super large
710
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
711
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
712
+
713
+ for start_m in range(0, hi):
714
+ dk, dv = bwd_dkdv_block_mn(
715
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
716
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
717
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
718
+ stride_qm, stride_qd, stride_dom, stride_dod,
719
+ q_indices, sparse_q_num_blocks,
720
+ MATMUL_PRECISION, RCP_LN2,
721
+ IS_FULL_BLOCKS,
722
+ )
723
+ # Increment pointers.
724
+ offset = get_offset_for_next_block(
725
+ start_m, q_indices, sparse_q_num_blocks,
726
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
727
+ )
728
+
729
+ qT_ptrs += offset * stride_qm
730
+ do_ptrs += offset * stride_dom
731
+ offs_m1 += offset
732
+
733
+ return dk, dv
734
+
735
+
736
+ @triton.jit
737
+ def bwd_dkdv_block_mn(
738
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
739
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
740
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
741
+ stride_qm, stride_qd, stride_dom, stride_dod,
742
+ q_indices, sparse_q_num_blocks,
743
+ MATMUL_PRECISION, RCP_LN2,
744
+ IS_FULL_BLOCKS,
745
+ ):
746
+ PRESCALE_QK : tl.constexpr = False
747
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
748
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
749
+ WRITE_DQ : tl.constexpr = True
750
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
751
+ OUTPUT_MAX : tl.constexpr = False
752
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
753
+ IS_DIVISIBLE : tl.constexpr = False
754
+ SM_SCALE : tl.constexpr = 0.08838834764831845
755
+ GQA_SHARED_HEADS : tl.constexpr = 4
756
+ HAS_FULL_BLOCKS : tl.constexpr = True
757
+ QK_HEAD_DIM : tl.constexpr = 128
758
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
759
+ V_HEAD_DIM : tl.constexpr = 128
760
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
761
+ SAFE_HEAD_DIM : tl.constexpr = True
762
+ BLOCK_M1 : tl.constexpr = 64
763
+ BLOCK_N1 : tl.constexpr = 128
764
+ BLOCK_M2 : tl.constexpr = 128
765
+ BLOCK_N2 : tl.constexpr = 64
766
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
767
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
768
+ INDEX_DTYPE : tl.constexpr = tl.int32
769
+
770
+
771
+ # NB reversed order since Q is transposed
772
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
773
+ # Load LSE before computing qk to reduce pipeline stall.
774
+ if IS_DIVISIBLE:
775
+ lse = tl.load(LSE + offs_m1)
776
+ else:
777
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
778
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
779
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
780
+ if not PRESCALE_QK:
781
+ qkT *= SM_SCALE
782
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
783
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
784
+ # The boundary check is done for the outer loop, but since we're iterating across the M dim here,
785
+ # the n reads can go out of bounds for PIDs spanning the KV_LEN boundary
786
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
787
+
788
+ pre_mod_scores = qkT
789
+ tmp22 = (qkT)
790
+ post_mod_scores = tmp22
791
+
792
+
793
+
794
+ if not IS_DIVISIBLE:
795
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
796
+
797
+ if not IS_FULL_BLOCKS:
798
+ tmp23 = (m)
799
+ tmp24 = tl.full([1], 0, tl.int32)
800
+ tmp25 = tmp23 < tmp24
801
+ tmp26 = (n)
802
+ tmp27 = tmp26 <= tmp23
803
+ tmp28 = tmp25 & tmp27
804
+ tmp29 = tmp23 >= tmp24
805
+ tmp30 = tmp26 < tmp24
806
+ tmp31 = tmp29 & tmp30
807
+ tmp32 = tmp30 == 0
808
+ tmp33 = tmp29 & tmp32
809
+ tmp34 = tmp23 - tmp24
810
+ tmp35 = tl.full([1], 16, tl.int32)
811
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
812
+ tmp37 = tmp26 - tmp24
813
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
814
+ tmp39 = tmp36 == tmp38
815
+ tmp40 = tmp33 & tmp39
816
+ tmp41 = tmp31 | tmp40
817
+ tmp42 = tmp28 | tmp41
818
+ mask_mod_output = tmp42
819
+
820
+ # (grads) apply mask for fully masked block
821
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
822
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
823
+ if not PRESCALE_QK:
824
+ post_mod_scores *= RCP_LN2
825
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
826
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
827
+ # Compute dV.
828
+ ppT = pT
829
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
830
+ if IS_DIVISIBLE:
831
+ Di = tl.load(DELTA + offs_m1)
832
+ else:
833
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
834
+ # Compute dP and dS.
835
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
836
+ dsT = pT * (dpT - Di[None, :])
837
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
838
+ tmp43 = (dsT)
839
+ grad_scores = tmp43
840
+
841
+
842
+
843
+ if not IS_DIVISIBLE:
844
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
845
+
846
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
847
+ if not WRITE_DQ:
848
+ idx_b = off_z
849
+ idx_h = off_hq
850
+ idx_m = m
851
+ idx_n = n
852
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
853
+
854
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
855
+ dsT = grad_scores
856
+ if not IS_FULL_BLOCKS:
857
+ # (grads) apply mask for partially unmasked block
858
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
861
+
862
+ return dk, dv
863
+
864
+ # Utility triton funcs
865
+ @triton.jit
866
+ def get_offset_for_next_block(
867
+ loop_iter, col_indices, total_blocks,
868
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
869
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
870
+ ):
871
+ if BLOCKS_ARE_CONTIGUOUS:
872
+ return BLOCK
873
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
874
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
875
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
876
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
877
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
878
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
879
+ return offset
880
+
881
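# Worked example of the jump logic in get_offset_for_next_block above, with illustrative
# values (not generated code): for SPARSE_BLOCK=128, BLOCK=64, SPARSE_BLOCK_MULTIPLE=2,
# cur_block=3 and next_block=7, the pointer sits at column 3*128 + 64 when the jump fires,
# and jump_to_block = (7 - 3)*128 - (2 - 1)*64 = 448 lands it at exactly 7*128; on
# non-boundary iterations the returned offset is simply BLOCK (64).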
+ @triton.jit
882
+ def get_bounded_indices(indices, max_len=None):
883
+ return indices % max_len if max_len is not None else indices
884
+
885
+ @triton.jit
886
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
887
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
888
+ return tl.load(block_ptr)
889
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
890
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
891
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
892
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
893
+ else:
894
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
895
+
896
+ @triton.jit
897
+ def load_checked_2d(
898
+ ptr,
899
+ offs_m,
900
+ offs_n,
901
+ stride_m,
902
+ stride_n,
903
+ IS_DIVISIBLE_M: tl.constexpr,
904
+ IS_DIVISIBLE_N: tl.constexpr,
905
+ M_LEN: tl.constexpr,
906
+ N_LEN: tl.constexpr,
907
+ ):
908
+ # Calculate final pointer if strides are provided
909
+ if stride_m is not None and stride_n is not None:
910
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
911
+
912
+ # Handle all masking cases
913
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
914
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
915
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
916
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
917
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
918
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
919
+ else: # Both divisible
920
+ return tl.load(ptr)
921
+ ''', device_str='cuda')
922
+
923
+
924
+ async_compile.wait(globals())
925
+ del async_compile
926
+
927
+ class Runner:
928
+ def __init__(self, partitions):
929
+ self.partitions = partitions
930
+
931
+ def recursively_apply_fns(self, fns):
932
+ new_callables = []
933
+ for fn, c in zip(fns, self.partitions):
934
+ new_callables.append(fn(c))
935
+ self.partitions = new_callables
936
+
937
+ def call(self, args):
938
+ primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args
939
+ args.clear()
940
+ s37 = primals_8
941
+ s0 = primals_9
942
+ assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
943
+ assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
944
+ assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
945
+ assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
946
+ assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
947
+ assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
948
+ assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
949
+ assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
950
+ assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
951
+ assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
952
+ assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
953
+ assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
954
+ assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
955
+ assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
956
+ assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
957
+ with torch.cuda._DeviceGuard(7):
958
+ torch.cuda.set_device(7)
959
+ buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
960
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
961
+ triton_per_fused_mul_0_xnumel = 32*s37
962
+ stream7 = get_raw_stream(7)
963
+ triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream7)
964
+ del getitem
965
+ del tangents_2
966
+ buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
967
+ buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
968
+ buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
969
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
970
+ stream7 = get_raw_stream(7)
971
+ triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream7)
972
+ del buf1
973
+ del getitem_1
974
+ del primals_10
975
+ del primals_11
976
+ del primals_12
977
+ del primals_13
978
+ del primals_14
979
+ del primals_15
980
+ del primals_16
981
+ del primals_2
982
+ del primals_4
983
+ del primals_6
984
+ del primals_7
985
+ del tangents_1
986
+ return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, )
987
+
988
+ runner = Runner(partitions=[])
989
+ call = runner.call
990
+ recursively_apply_fns = runner.recursively_apply_fns
991
+
992
+
993
+ def benchmark_compiled_module(times=10, repeat=10):
994
+ from torch._dynamo.testing import rand_strided
995
+ from torch._inductor.utils import print_performance
996
+ primals_8 = 128
997
+ primals_9 = 128
998
+ primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
999
+ primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
1000
+ primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
1001
+ primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1002
+ primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1003
+ primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1004
+ primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1005
+ primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1006
+ primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1007
+ primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
1008
+ primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
1009
+ getitem = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
1010
+ getitem_1 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:7', dtype=torch.float32)
1011
+ tangents_1 = rand_strided((1, 32, 128, 128), (524288, 16384, 128, 1), device='cuda:7', dtype=torch.bfloat16)
1012
+ tangents_2 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:7', dtype=torch.float32)
1013
+ fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2])
1014
+ return print_performance(fn, times=times, repeat=repeat)
1015
+
1016
+
1017
+ if __name__ == "__main__":
1018
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1019
+ compiled_module_main('None', benchmark_compiled_module)
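
For orientation: the wrapper above is what Inductor emits for the backward pass of torch.nn.attention.flex_attention, and running the file directly benchmarks it through compiled_module_main. Below is a minimal sketch of the kind of user-level call that produces such cached kernels; the causal mask_mod and the 128-token sequence lengths are illustrative assumptions, while the head layout (32 query heads sharing 8 KV heads, head dim 128, bfloat16) matches the shapes asserted in the wrapper.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # illustrative mask_mod; the actual SpecForge mask differs
    return q_idx >= kv_idx

q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cuda")
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask, enable_gqa=True)
out.sum().backward()  # the backward pass runs dQ/dK/dV kernels like triton_tem_fused_mul_1 above
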
progress/SpecForge/cache/compiled_kernels/5z/c5zh2j5k5rlsr5zd4tfbvoplpwmbtizbrldjq2hw4nndmjztlcuy.py ADDED
@@ -0,0 +1,879 @@
1
+ # AOT ID: ['3_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wp/cwpsdlduqm7qdaycs5k762qtpc4rkcap2vuyn5ripw5tdmb2n2ce.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1]
43
+ # %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1]
44
+ # %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1]
45
+ # %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
46
+ # %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
47
+ # %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg9_1]
48
+ # %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg6_1]
49
+ # %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg10_1]
50
+ # %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg11_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %buf2
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=2,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ GQA_SHARED_HEADS : tl.constexpr = 4
80
+ HAS_FULL_BLOCKS : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831845
82
+ SPLIT_KV : tl.constexpr = 32
83
+ QK_HEAD_DIM : tl.constexpr = 128
84
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
85
+ V_HEAD_DIM : tl.constexpr = 128
86
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
87
+ SAFE_HEAD_DIM : tl.constexpr = True
88
+ BLOCK_M : tl.constexpr = 256
89
+ SAFE_M_BOUNDARY : tl.constexpr = False
90
+ SAFE_N_BOUNDARY : tl.constexpr = True
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
93
+ USE_TMA : tl.constexpr = False
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ M = arg_M
99
+ L = arg_L
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ # Q: Query, K: Key, V: Value
107
+ # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
108
+ # M: Number of queries, N: Number of keys/values
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
112
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
113
+ # (Modifiable) Config options:
114
+ # SPLIT_KV: number of blocks K & V are split into
115
+ # TILE_KV: length of each local KV split
116
+ # BLOCK_M: block size that Q is padded along seqlen dim.
117
+ # BLOCK_N: block size of K & V along N dimension.
118
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
119
+ #
120
+ # change of base out of the loop
121
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
122
+ # is not masked out? If so, we can skip an extra safety check
123
+ # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
124
+ # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
125
+
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
127
+ #
128
+ # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
129
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
130
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
131
+ #
132
+ #
133
+ # Output: ACC output accumulated across local KV split.
134
+
135
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
136
+
137
+ # Define Q Strides
138
+ stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
139
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
140
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
141
+ stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
142
+ stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
143
+
144
+
145
+ Z = 1
146
+ ZKV = 1
147
+ HKV = 8
148
+ G: tl.constexpr = GQA_SHARED_HEADS
149
+ HQ = HKV * G
150
+ Q_LEN = ks0
151
+ KV_LEN = ks1
152
+
153
+ MATMUL_PRECISION = Q.dtype.element_ty
154
+
155
+ # Make sure each split is a multiple of BLOCK_N
156
+ TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
157
+ TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
158
+ TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
159
+
160
+ off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
161
+ off_zkv = off_z % ZKV
162
+ off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
163
+ off_t = tl.program_id(1).to(INDEX_DTYPE)
164
+
165
+ q_offset = off_z * stride_qz + off_hkv * stride_qh
166
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
167
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
168
+
169
+ K = K + k_offset
170
+ V = V + v_offset
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_z % SPARSE_Z
176
+ sparse_idx_h = off_hkv % SPARSE_HQ
177
+
178
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
179
+ SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
180
+
181
+ # initialize pointer to m and l
182
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
183
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
184
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
185
+
186
+ # initialize offsets
187
+ tl.device_assert(BLOCK_M % G == 0)
188
+ BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
189
+ off_g = tl.arange(0, G) # [G]
190
+ offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
191
+ offs_hq = offs_g + off_hkv * G
192
+ off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
193
+ offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
194
+ offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
195
+ offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
196
+
197
+ # Get HZ offsets for KV_NUM_BLKS and KV_IDX
198
+ stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
199
+ sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
200
+ stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
201
+ sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
202
+
203
+ # Calculate KV blocks that belong to this CTA.
204
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
205
+ block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
206
+
207
+ q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
208
+
209
+ if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
210
+ q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
211
+ elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
212
+ q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
213
+ elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
214
+ q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
215
+ else:
216
+ q = tl.load(Q + q_offset + q_range)
217
+
218
+ q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
219
+
220
+
221
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
222
+ # find first kv block we are loading and the number of blocks we are loading
223
+ # Offset the kv_indices tensor by the correct batch and head
224
+ kv_indices = KV_IDX + sparse_idx_hz_offset
225
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
226
+ MAX_KV_IDX = 1
227
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
228
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
229
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
230
+ # first kv block we're loading
231
+
232
+ # last valid block according to sparse mask
233
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
234
+
235
+ offs_n = tl.arange(0, BLOCK_N) + off_n
236
+
237
+ desc_k = None
238
+ desc_v = None
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
242
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
243
+ # accumulated values
244
+ acc, l_i, m_i,
245
+ #offsets
246
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
247
+ off_n,
248
+ #block sparse data
249
+ kv_indices, kv_num_blocks,
250
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=False,
254
+ )
255
+
256
+
257
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
258
+ # We know these blocks are guaranteed to be "full", so we don't need to
259
+ # apply mask_mod to them - only score_mod
260
+ if HAS_FULL_BLOCKS:
261
+ kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
262
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
263
+ # Assign full blocks in reverse order of off_t. Prioritize the last CTA.
264
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
265
+ block_n_end = block_n_start + TILE_KV_MULTIPLE
266
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
267
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
268
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
269
+
270
+ # last valid block according to sparse mask
271
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
272
+
273
+ offs_n = tl.arange(0, BLOCK_N) + off_n
274
+
275
+ acc, l_i, m_i = forward_inner(
276
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
277
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
278
+ # accumulated values
279
+ acc, l_i, m_i,
280
+ #offsets
281
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
282
+ off_n,
283
+ #block sparse data
284
+ kv_indices, kv_num_blocks,
285
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
286
+ MATMUL_PRECISION,
287
+ stride_kk, stride_kn, stride_vn, stride_vk,
288
+ IS_FULL_BLOCKS=True,
289
+ )
290
+
291
+ m_offset = off_t * stride_mt + off_z * stride_mz
292
+ l_offset = off_t * stride_lt + off_z * stride_lz
293
+
294
+ M_block_ptr = tl.make_block_ptr(
295
+ base=M + m_offset,
296
+ shape=(G, Q_LEN), # (G, M)
297
+ strides=(stride_mh, stride_mm),
298
+ offsets=(off_hkv*G, 0),
299
+ block_shape=(G, BLOCK_M_PER_HQ),
300
+ order=(1, 0)
301
+ )
302
+ L_block_ptr = tl.make_block_ptr(
303
+ base=L + l_offset,
304
+ shape=(G, Q_LEN), # (G, M)
305
+ strides=(stride_lh, stride_lm),
306
+ offsets=(off_hkv*G, 0),
307
+ block_shape=(G, BLOCK_M_PER_HQ),
308
+ order=(1, 0)
309
+ )
310
+
311
+ # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
312
+ m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
313
+ l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
314
+ if SAFE_M_BOUNDARY:
315
+ tl.store(M_block_ptr, m_i)
316
+ tl.store(L_block_ptr, l_i)
317
+ else:
318
+ tl.store(M_block_ptr, m_i, boundary_check=(1,))
319
+ tl.store(L_block_ptr, l_i, boundary_check=(1,))
320
+
321
+ # -- store output
322
+ idx_z = off_z
323
+ idx_t = off_t
324
+ idx_hq = off_hkv*G + off_g[:, None, None]
325
+ idx_m = off_m[None, :, None]
326
+ idx_d = offs_vd[None, None, :]
327
+
328
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
329
+ acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
330
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
331
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
332
+
333
+
334
+ # Utility triton funcs
335
+ @triton.jit
336
+ def get_offset_for_next_block(
337
+ loop_iter, col_indices, total_blocks,
338
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
339
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
340
+ ):
341
+ if BLOCKS_ARE_CONTIGUOUS:
342
+ return BLOCK
343
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
344
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
345
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
346
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
347
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
348
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
349
+ return offset
350
+
351
+ @triton.jit
352
+ def get_bounded_indices(indices, max_len=None):
353
+ return indices % max_len if max_len is not None else indices
354
+
355
+ @triton.jit
356
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
357
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
358
+ return tl.load(block_ptr)
359
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
360
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
361
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
362
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
363
+ else:
364
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
365
+
366
+ @triton.jit
367
+ def load_checked_2d(
368
+ ptr,
369
+ offs_m,
370
+ offs_n,
371
+ stride_m,
372
+ stride_n,
373
+ IS_DIVISIBLE_M: tl.constexpr,
374
+ IS_DIVISIBLE_N: tl.constexpr,
375
+ M_LEN: tl.constexpr,
376
+ N_LEN: tl.constexpr,
377
+ ):
378
+ # Calculate final pointer if strides are provided
379
+ if stride_m is not None and stride_n is not None:
380
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
381
+
382
+ # Handle all masking cases
383
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
384
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
385
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
386
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
387
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
388
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
389
+ else: # Both divisible
390
+ return tl.load(ptr)
391
+
392
+
393
+ # Common Imports
394
+ @triton.jit
395
+ def forward_block_mn(
396
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
397
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
398
+ # accumulated values
399
+ acc, l_i, m_i,
400
+ # Offsets
401
+ off_z, off_h, offs_m, offs_n,
402
+ # Offsets needed for TMA loads
403
+ kv_start,
404
+ kv_offset,
405
+ MATMUL_PRECISION, RCP_LN2,
406
+ # Strides for K and V
407
+ stride_kk, stride_kn, stride_vn, stride_vk,
408
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
409
+
410
+ ):
411
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
412
+ PRESCALE_QK : tl.constexpr = False
413
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
414
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
415
+ WRITE_DQ : tl.constexpr = True
416
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
417
+ OUTPUT_MAX : tl.constexpr = False
418
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
419
+ IS_DIVISIBLE : tl.constexpr = False
420
+ GQA_SHARED_HEADS : tl.constexpr = 4
421
+ HAS_FULL_BLOCKS : tl.constexpr = True
422
+ SM_SCALE : tl.constexpr = 0.08838834764831845
423
+ SPLIT_KV : tl.constexpr = 32
424
+ QK_HEAD_DIM : tl.constexpr = 128
425
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
426
+ V_HEAD_DIM : tl.constexpr = 128
427
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
428
+ SAFE_HEAD_DIM : tl.constexpr = True
429
+ BLOCK_M : tl.constexpr = 256
430
+ SAFE_M_BOUNDARY : tl.constexpr = False
431
+ SAFE_N_BOUNDARY : tl.constexpr = True
432
+ BLOCK_N : tl.constexpr = 64
433
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
434
+ USE_TMA : tl.constexpr = False
435
+ INDEX_DTYPE : tl.constexpr = tl.int32
436
+
437
+
438
+ # -- load k --
439
+ # NB reversed order since K is transposed
440
+ kv_base_offset = kv_start + kv_offset
441
+
442
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
443
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
444
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
445
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
446
+
447
+ k = tl.trans(k)
448
+ # -- compute qk ---
449
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
450
+ if not PRESCALE_QK:
451
+ qk *= SM_SCALE
452
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
453
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
454
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
455
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
456
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
457
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
458
+
459
+ tmp0 = (qk)
460
+ post_mod_scores = tmp0
461
+
462
+
463
+ if CHECK_BLOCK_BOUNDARY:
464
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
465
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
466
+
467
+ if not IS_FULL_BLOCKS:
468
+ tmp1 = (m)
469
+ tmp2 = tl.full([1], 0, tl.int32)
470
+ tmp3 = tmp1 < tmp2
471
+ tmp4 = (n)
472
+ tmp5 = tmp4 <= tmp1
473
+ tmp6 = tmp3 & tmp5
474
+ tmp7 = tmp1 >= tmp2
475
+ tmp8 = tmp4 < tmp2
476
+ tmp9 = tmp7 & tmp8
477
+ tmp10 = tmp8 == 0
478
+ tmp11 = tmp7 & tmp10
479
+ tmp12 = tmp1 - tmp2
480
+ tmp13 = tl.full([1], 16, tl.int32)
481
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
482
+ tmp15 = tmp4 - tmp2
483
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
484
+ tmp17 = tmp14 == tmp16
485
+ tmp18 = tmp11 & tmp17
486
+ tmp19 = tmp9 | tmp18
487
+ tmp20 = tmp6 | tmp19
488
+ mask_mod_output = tmp20
489
+
490
+
491
+ if CHECK_BLOCK_BOUNDARY:
492
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
493
+ # apply mask for partially unmasked blocks
494
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
495
+
496
+ if not PRESCALE_QK:
497
+ post_mod_scores *= RCP_LN2
498
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
499
+
500
+ # -- compute scaling constant ---
501
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
502
+ if not ROWS_GUARANTEED_SAFE:
503
+ masked_out_rows = (m_ij == float("-inf"))
504
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
505
+ else:
506
+ m_ij_masked = m_ij
507
+
508
+ alpha = tl.math.exp2(m_i - m_ij_masked)
509
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
510
+
511
+ # NB: l_i update is pulled up here since it's a bit faster
512
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
513
+ # m_ij
514
+ l_i = l_i * alpha + tl.sum(p, 1)
515
+ # # -- scale and update acc --
516
+ acc = acc * alpha[:, None]
517
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
518
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
519
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
520
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
521
+
522
+ # -- update m_i
523
+ m_i = m_ij
524
+
525
+ return acc, l_i, m_i
526
+
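# The update in forward_block_mn above follows the standard streaming-softmax recurrence
# (sketched here for reference, not generated code): with running row-max m_i and running
# denominator l_i, a new block of scores S gives
#     m_new = max(m_i, rowmax(S))
#     alpha = exp2(m_i - m_new)               # exp2 because scores were scaled by RCP_LN2
#     l_i   = l_i * alpha + rowsum(exp2(S - m_new))
#     acc   = acc * alpha + exp2(S - m_new) @ V
# so acc / l_i equals the softmax-weighted sum of V once every KV block has been folded in;
# in this split-KV kernel the final division happens in a later cross-CTA reduction.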
527
+ @triton.jit
528
+ def forward_inner(
529
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
530
+ q, K, V,
531
+ desc_k, desc_v, Q_LEN, KV_LEN,
532
+ # accumulated values
533
+ acc, l_i, m_i,
534
+ # Offsets used as inputs to score_mod & mask_mod
535
+ # of size [BLOCK_M, BLOCK_N] or scalar.
536
+ off_z, off_h, offs_m, offs_n,
537
+ # Offsets needed for TMA loads
538
+ kv_start,
539
+ # blocksparse data
540
+ kv_indices, kv_num_blocks,
541
+ # start kv and end kv block
542
+ block_n_start, block_n_end,
543
+ MATMUL_PRECISION,
544
+ # Strides for K and V
545
+ stride_kk, stride_kn, stride_vn, stride_vk,
546
+ IS_FULL_BLOCKS,
547
+ ):
548
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
549
+ PRESCALE_QK : tl.constexpr = False
550
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
551
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
552
+ WRITE_DQ : tl.constexpr = True
553
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
554
+ OUTPUT_MAX : tl.constexpr = False
555
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
556
+ IS_DIVISIBLE : tl.constexpr = False
557
+ GQA_SHARED_HEADS : tl.constexpr = 4
558
+ HAS_FULL_BLOCKS : tl.constexpr = True
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ SPLIT_KV : tl.constexpr = 32
561
+ QK_HEAD_DIM : tl.constexpr = 128
562
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
563
+ V_HEAD_DIM : tl.constexpr = 128
564
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
565
+ SAFE_HEAD_DIM : tl.constexpr = True
566
+ BLOCK_M : tl.constexpr = 256
567
+ SAFE_M_BOUNDARY : tl.constexpr = False
568
+ SAFE_N_BOUNDARY : tl.constexpr = True
569
+ BLOCK_N : tl.constexpr = 64
570
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
571
+ USE_TMA : tl.constexpr = False
572
+ INDEX_DTYPE : tl.constexpr = tl.int32
573
+
574
+
575
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+
578
+ if PRESCALE_QK:
579
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
580
+
581
+ kv_offset = 0
582
+
583
+ # loop over k, v and update accumulator until block_n_end
584
+ for start_n in range(block_n_start, block_n_end):
585
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
586
+ if IS_DIVISIBLE:
587
+ acc, l_i, m_i = forward_block_mn(
588
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
589
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
590
+ # accumulated values
591
+ acc, l_i, m_i,
592
+ # Offsets
593
+ off_z, off_h, offs_m, offs_n,
594
+ # Offsets needed for TMA loads
595
+ kv_start,
596
+ kv_offset,
597
+ MATMUL_PRECISION, RCP_LN2,
598
+ # Strides for K and V
599
+ stride_kk, stride_kn, stride_vn, stride_vk,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ else:
603
+ # Benchmarks show that even if we apply mod & mask to each block for a non-divisible seqlen,
604
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
605
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
606
+ # to the last block because it's a lot faster.
607
+ acc, l_i, m_i = forward_block_mn(
608
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
609
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
610
+ # accumulated values
611
+ acc, l_i, m_i,
612
+ # Offsets
613
+ off_z, off_h, offs_m, offs_n,
614
+ # Offsets needed for TMA loads
615
+ kv_start,
616
+ kv_offset,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ # Strides for K and V
619
+ stride_kk, stride_kn, stride_vn, stride_vk,
620
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
621
+ )
622
+
623
+
624
+
625
+ offset = get_offset_for_next_block(
626
+ start_n, kv_indices, kv_num_blocks,
627
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
628
+ )
629
+
630
+ offs_n = offs_n + offset
631
+ kv_offset += offset
632
+
633
+
634
+ return acc, l_i, m_i
635
+ ''', device_str='cuda')
636
+
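The forward template above keeps a running base-2 softmax: per KV tile it updates the row maximum m_i, rescales the running denominator l_i and accumulator acc by exp2(m_i - m_ij), and adds the new tile's contribution. A minimal NumPy sketch of that update for a single query row follows; it is not part of the cached kernel, the names are illustrative, and the kernel's special handling of fully masked rows is omitted.

import numpy as np

def online_softmax_update(m_i, l_i, acc, scores, v_tile):
    # scores: [BLOCK_N] tile scores already in base-2 units (qk * SM_SCALE * RCP_LN2)
    # v_tile: [BLOCK_N, V_HEAD_DIM] value tile; acc: [V_HEAD_DIM] running weighted sum
    m_ij = max(m_i, scores.max())        # new running row maximum
    alpha = np.exp2(m_i - m_ij)          # rescale factor for the previous accumulator
    p = np.exp2(scores - m_ij)           # unnormalized tile probabilities
    l_i = l_i * alpha + p.sum()          # running softmax denominator
    acc = acc * alpha + p @ v_tile       # running numerator (weighted values)
    return m_ij, l_i, acc

Starting from m_i = -inf, l_i = 0 and acc = 0, applying this per tile and finally dividing acc by l_i reproduces softmax(scores) @ V for the row.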
637
+
638
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/3p/c3pdrhexk4rwol7f5l5vh7n543dj6piq6gw5k66g2p4vlyhopnop.py
639
+ # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
640
+ # Source node to ATen node mapping:
641
+ # flex_attention => flex_attention
642
+ # lse_scaled => mul_9
643
+ # Graph fragment:
644
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
645
+ # %buf4 : Tensor = PlaceHolder[target=buf4]
646
+ # %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
647
+ # %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
648
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
649
+ # %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
650
+ # return %buf5,%buf7,%mul_9
651
+ triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', '''
652
+ import triton
653
+ import triton.language as tl
654
+
655
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
656
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
657
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
658
+ triton_helpers.set_driver_to_gpu()
659
+
660
+ @triton_heuristics.persistent_reduction(
661
+ size_hints={'x': 2048, 'r0_': 32},
662
+ reduction_hint=ReductionHint.DEFAULT,
663
+ filename=__file__,
664
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
665
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
666
+ )
667
+ @triton.jit
668
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
669
+ r0_numel = 32
670
+ R0_BLOCK: tl.constexpr = 32
671
+ rnumel = r0_numel
672
+ RBLOCK: tl.constexpr = R0_BLOCK
673
+ xoffset = tl.program_id(0) * XBLOCK
674
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
675
+ xmask = xindex < xnumel
676
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
677
+ r0_offset = 0
678
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
679
+ roffset = r0_offset
680
+ rindex = r0_index
681
+ r0_1 = r0_index
682
+ x0 = xindex
683
+ x2 = (xindex % ks0)
684
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
685
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
686
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
687
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
688
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
689
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
690
+ tmp6 = float("-inf")
691
+ tmp7 = tmp4 == tmp6
692
+ tmp8 = tmp0 - tmp4
693
+ tmp9 = 0.0
694
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
695
+ tmp11 = libdevice.exp2(tmp10)
696
+ tmp12 = tmp5 * tmp11
697
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
698
+ tmp15 = tl.where(xmask, tmp13, 0)
699
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
700
+ tmp17 = 1.0
701
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
702
+ tmp19 = libdevice.log2(tmp18)
703
+ tmp20 = tmp19 + tmp4
704
+ tmp21 = 0.6931471805599453
705
+ tmp22 = tmp20 * tmp21
706
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
707
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
708
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
709
+ ''', device_str='cuda')
710
+
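triton_per_fused_mul_1 above combines the per-split results of the decode kernel: it reduces the per-split row maxima and exp2-sums (32 splits here) to a single log-sum-exp in base 2, then converts it to natural log by multiplying by ln(2) = 0.6931471805599453. A minimal NumPy sketch of that reduction for one row (not part of the cached kernel; the masked-row special case where the sum is forced to 1 is omitted):

import numpy as np

def combine_split_lse(split_max, split_sum):
    # split_max, split_sum: shape [num_splits]; per-split base-2 maxima and exp2-sums
    g = split_max.max()                            # global base-2 row maximum
    total = np.sum(split_sum * np.exp2(split_max - g))
    lse_base2 = np.log2(total) + g                 # combined log-sum-exp, still base 2
    return lse_base2 * 0.6931471805599453          # rescale by ln(2) to natural log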
711
+
712
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/so/csovo4ylg3oadsjdsgqzxojai65n6xwa5n7magfct6j23x35wkeq.py
713
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
714
+ # Source node to ATen node mapping:
715
+ # flex_attention => flex_attention, getitem
716
+ # Graph fragment:
717
+ # %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf2]
718
+ # %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
719
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
720
+ # %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf8]
721
+ # %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
722
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
723
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {})
724
+ # return %buf8,%getitem
725
+ triton_per_fused_2 = async_compile.triton('triton_per_fused_2', '''
726
+ import triton
727
+ import triton.language as tl
728
+
729
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
730
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
731
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
732
+ triton_helpers.set_driver_to_gpu()
733
+
734
+ @triton_heuristics.persistent_reduction(
735
+ size_hints={'x': 262144, 'r0_': 32},
736
+ reduction_hint=ReductionHint.DEFAULT,
737
+ filename=__file__,
738
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
739
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
740
+ )
741
+ @triton.jit
742
+ def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
743
+ r0_numel = 32
744
+ R0_BLOCK: tl.constexpr = 32
745
+ rnumel = r0_numel
746
+ RBLOCK: tl.constexpr = R0_BLOCK
747
+ xoffset = tl.program_id(0) * XBLOCK
748
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
749
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
750
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
751
+ r0_offset = 0
752
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
753
+ roffset = r0_offset
754
+ rindex = r0_index
755
+ r0_2 = r0_index
756
+ x5 = xindex
757
+ x1 = xindex // 128
758
+ x0 = (xindex % 128)
759
+ x3 = ((xindex // 128) % ks0)
760
+ x4 = xindex // ks1
761
+ tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
762
+ tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
763
+ tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
764
+ tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
765
+ tmp2 = float("-inf")
766
+ tmp3 = tmp1 == tmp2
767
+ tmp5 = tmp4 - tmp1
768
+ tmp6 = 0.0
769
+ tmp7 = tl.where(tmp3, tmp6, tmp5)
770
+ tmp8 = libdevice.exp2(tmp7)
771
+ tmp9 = tmp0 * tmp8
772
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
773
+ tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
774
+ tmp14 = 1.0
775
+ tmp15 = tl.where(tmp3, tmp14, tmp13)
776
+ tmp16 = (tmp12 / tmp15)
777
+ tmp17 = tmp16.to(tl.float32)
778
+ tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
779
+ ''', device_str='cuda')
780
+
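triton_per_fused_2 above performs the matching reduction for the attention output: each split's partial output is rescaled by exp2(split_max - global_max), the splits are summed, and the result is divided by the combined denominator before the bf16 store. A minimal sketch of that step (not the kernel itself; names are illustrative):

import numpy as np

def reduce_split_outputs(partial_out, split_max, global_max, denom):
    # partial_out: [num_splits, V_HEAD_DIM]; split_max: [num_splits]; denom: combined exp2-sum
    weights = np.exp2(split_max - global_max)       # per-split rescale factors
    out = (partial_out * weights[:, None]).sum(axis=0)
    # the kernel substitutes denom = 1.0 when global_max == -inf (fully masked row)
    return out / denom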
781
+
782
+ async_compile.wait(globals())
783
+ del async_compile
784
+
785
+ class Runner:
786
+ def __init__(self, partitions):
787
+ self.partitions = partitions
788
+
789
+ def recursively_apply_fns(self, fns):
790
+ new_callables = []
791
+ for fn, c in zip(fns, self.partitions):
792
+ new_callables.append(fn(c))
793
+ self.partitions = new_callables
794
+
795
+ def call(self, args):
796
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
797
+ args.clear()
798
+ s50 = arg0_1
799
+ s0 = arg2_1
800
+ s43 = arg4_1
801
+ s37 = arg7_1
802
+ s71 = arg8_1
803
+ assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
804
+ assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
805
+ assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
806
+ assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
807
+ assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
808
+ assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
809
+ assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
810
+ assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
811
+ assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
812
+ assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
813
+ assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
814
+ with torch.cuda._DeviceGuard(7):
815
+ torch.cuda.set_device(7)
816
+ buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
817
+ buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
818
+ buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32)
819
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
820
+ stream7 = get_raw_stream(7)
821
+ triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream7)
822
+ del arg10_1
823
+ del arg11_1
824
+ del arg1_1
825
+ del arg3_1
826
+ del arg5_1
827
+ del arg6_1
828
+ del arg9_1
829
+ buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32)
830
+ buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
831
+ buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
832
+ # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
833
+ triton_per_fused_mul_1_xnumel = 32*s37
834
+ stream7 = get_raw_stream(7)
835
+ triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream7)
836
+ del buf1
837
+ ps0 = 128*s37
838
+ buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
839
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
840
+ triton_per_fused_2_xnumel = 4096*s37
841
+ stream7 = get_raw_stream(7)
842
+ triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream7)
843
+ del buf0
844
+ del buf2
845
+ del buf5
846
+ del buf7
847
+ return (buf9, buf10, )
848
+
849
+ runner = Runner(partitions=[])
850
+ call = runner.call
851
+ recursively_apply_fns = runner.recursively_apply_fns
852
+
853
+
854
+ def benchmark_compiled_module(times=10, repeat=10):
855
+ from torch._dynamo.testing import rand_strided
856
+ from torch._inductor.utils import print_performance
857
+ arg0_1 = 48
858
+ arg1_1 = rand_strided((1, 32, 48, 128), (196608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
859
+ arg2_1 = 48
860
+ arg3_1 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
861
+ arg4_1 = 48
862
+ arg5_1 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
863
+ arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
864
+ arg7_1 = 48
865
+ arg8_1 = 48
866
+ arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
867
+ arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
868
+ arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
869
+ arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
870
+ arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
871
+ arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
872
+ arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
873
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
874
+ return print_performance(fn, times=times, repeat=repeat)
875
+
876
+
877
+ if __name__ == "__main__":
878
+ from torch._inductor.wrapper_benchmark import compiled_module_main
879
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/66/c66vzxeqjq4tywx6ezsscs3u3rb6yxac26rzmkwrlzc3kmkcnhlf.py ADDED
@@ -0,0 +1,799 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4521984, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1130496, 128, 1024, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1130496, 128, 1024, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4521984, 141312, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4521984, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1130496, 128, 1024, 1
101
+
102
+ ZQ = 1
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 1104
106
+ ZKV = 1
107
+ KV_LEN = 1104
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 1
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 9
148
+ stride_kv_idx_h = 81
149
+ stride_kv_idx_m = 9
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 9
245
+ stride_q_idx_h = 81
246
+ stride_q_idx_n = 9
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 141312*off_hkv + 1130496*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831845
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 1104
385
+ KV_LEN = 1104
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831845
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but since we're iterating across the N dim here,
466
+ # it's possible that the M indices read out of bounds for the PIDs spanning the Q_LEN boundary.
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = (m)
480
+ tmp2 = tl.full([1], 0, tl.int32)
481
+ tmp3 = tmp1 < tmp2
482
+ tmp4 = (n)
483
+ tmp5 = tmp4 <= tmp1
484
+ tmp6 = tmp3 & tmp5
485
+ tmp7 = tmp1 >= tmp2
486
+ tmp8 = tmp4 < tmp2
487
+ tmp9 = tmp7 & tmp8
488
+ tmp10 = tmp8 == 0
489
+ tmp11 = tmp7 & tmp10
490
+ tmp12 = tmp1 - tmp2
491
+ tmp13 = tl.full([1], 16, tl.int32)
492
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
493
+ tmp15 = tmp4 - tmp2
494
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
495
+ tmp17 = tmp14 == tmp16
496
+ tmp18 = tmp11 & tmp17
497
+ tmp19 = tmp9 | tmp18
498
+ tmp20 = tmp6 | tmp19
499
+ mask_mod_output = tmp20
500
+
501
+
502
+ # apply mask for partial masked block
503
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
504
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
505
+ if not PRESCALE_QK:
506
+ post_mod_scores *= RCP_LN2
507
+ p = tl.math.exp2(post_mod_scores - lse)
508
+ # Compute dP and dS.
509
+ # NB reversed order since V is transposed
510
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
511
+
512
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
513
+ ds = p * (dp - Di[:, None])
514
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
515
+ tmp21 = (ds)
516
+ grad_scores = tmp21
517
+
518
+
519
+ if not IS_DIVISIBLE:
520
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
521
+
522
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
523
+ if WRITE_DQ:
524
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
525
+
526
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
527
+ ds = grad_scores
528
+
529
+ if not IS_FULL_BLOCKS:
530
+ # (grads) apply mask for partially unmasked block
531
+ ds = tl.where(mask_mod_output, ds, 0.0)
532
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533
+ ds = ds.to(MATMUL_PRECISION)
534
+ # Compute dQ.
535
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
536
+
537
+ return dq
538
+
539
+
540
+ @triton.jit
541
+ def bwd_dkdv_inner(
542
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
543
+ Q, DO, DELTA, LSE, # pointers
544
+ dk, dv, k, v,
545
+ off_z, off_hq, offs_n1, offs_m1,
546
+ stride_qm, stride_qd, stride_dom, stride_dod,
547
+ q_indices, sparse_q_num_blocks,
548
+ MATMUL_PRECISION,
549
+ IS_FULL_BLOCKS,
550
+ ):
551
+ PRESCALE_QK : tl.constexpr = False
552
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
553
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
554
+ WRITE_DQ : tl.constexpr = True
555
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
556
+ OUTPUT_MAX : tl.constexpr = False
557
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
558
+ IS_DIVISIBLE : tl.constexpr = False
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ GQA_SHARED_HEADS : tl.constexpr = 4
561
+ HAS_FULL_BLOCKS : tl.constexpr = True
562
+ QK_HEAD_DIM : tl.constexpr = 128
563
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
564
+ V_HEAD_DIM : tl.constexpr = 128
565
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
566
+ SAFE_HEAD_DIM : tl.constexpr = True
567
+ BLOCK_M1 : tl.constexpr = 64
568
+ BLOCK_N1 : tl.constexpr = 128
569
+ BLOCK_M2 : tl.constexpr = 128
570
+ BLOCK_N2 : tl.constexpr = 64
571
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
572
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
573
+ INDEX_DTYPE : tl.constexpr = tl.int32
574
+
575
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+ Q_LEN = 1104
578
+ KV_LEN = 1104
579
+
580
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
581
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
582
+
583
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
584
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
585
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
586
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
587
+
588
+ # The minimum is needed to handle the case where we run with a super large
589
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
590
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
591
+
592
+ for start_m in range(0, hi):
593
+ dk, dv = bwd_dkdv_block_mn(
594
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
595
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
596
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
597
+ stride_qm, stride_qd, stride_dom, stride_dod,
598
+ q_indices, sparse_q_num_blocks,
599
+ MATMUL_PRECISION, RCP_LN2,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ # Increment pointers.
603
+ offset = get_offset_for_next_block(
604
+ start_m, q_indices, sparse_q_num_blocks,
605
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
606
+ )
607
+
608
+ qT_ptrs += offset * stride_qm
609
+ do_ptrs += offset * stride_dom
610
+ offs_m1 += offset
611
+
612
+ return dk, dv
613
+
614
+
615
+ @triton.jit
616
+ def bwd_dkdv_block_mn(
617
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
618
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
619
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
620
+ stride_qm, stride_qd, stride_dom, stride_dod,
621
+ q_indices, sparse_q_num_blocks,
622
+ MATMUL_PRECISION, RCP_LN2,
623
+ IS_FULL_BLOCKS,
624
+ ):
625
+ PRESCALE_QK : tl.constexpr = False
626
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
627
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
628
+ WRITE_DQ : tl.constexpr = True
629
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
630
+ OUTPUT_MAX : tl.constexpr = False
631
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
632
+ IS_DIVISIBLE : tl.constexpr = False
633
+ SM_SCALE : tl.constexpr = 0.08838834764831845
634
+ GQA_SHARED_HEADS : tl.constexpr = 4
635
+ HAS_FULL_BLOCKS : tl.constexpr = True
636
+ QK_HEAD_DIM : tl.constexpr = 128
637
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
638
+ V_HEAD_DIM : tl.constexpr = 128
639
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
640
+ SAFE_HEAD_DIM : tl.constexpr = True
641
+ BLOCK_M1 : tl.constexpr = 64
642
+ BLOCK_N1 : tl.constexpr = 128
643
+ BLOCK_M2 : tl.constexpr = 128
644
+ BLOCK_N2 : tl.constexpr = 64
645
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
646
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
647
+ INDEX_DTYPE : tl.constexpr = tl.int32
648
+
649
+
650
+ # NB reversed order since Q is transposed
651
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
652
+ # Load LSE before computing qk to reduce pipeline stall.
653
+ if IS_DIVISIBLE:
654
+ lse = tl.load(LSE + offs_m1)
655
+ else:
656
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
657
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
658
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
659
+ if not PRESCALE_QK:
660
+ qkT *= SM_SCALE
661
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
662
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
663
+ # The boundary check is done for the outer loop, but since we're iterating across the M dim here,
664
+ # it's possible that the n indices read out of bounds for the PIDs spanning the KV_LEN boundary.
665
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
666
+
667
+ pre_mod_scores = qkT
668
+ tmp22 = (qkT)
669
+ post_mod_scores = tmp22
670
+
671
+
672
+
673
+ if not IS_DIVISIBLE:
674
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
675
+
676
+ if not IS_FULL_BLOCKS:
677
+ tmp23 = (m)
678
+ tmp24 = tl.full([1], 0, tl.int32)
679
+ tmp25 = tmp23 < tmp24
680
+ tmp26 = (n)
681
+ tmp27 = tmp26 <= tmp23
682
+ tmp28 = tmp25 & tmp27
683
+ tmp29 = tmp23 >= tmp24
684
+ tmp30 = tmp26 < tmp24
685
+ tmp31 = tmp29 & tmp30
686
+ tmp32 = tmp30 == 0
687
+ tmp33 = tmp29 & tmp32
688
+ tmp34 = tmp23 - tmp24
689
+ tmp35 = tl.full([1], 16, tl.int32)
690
+ tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
691
+ tmp37 = tmp26 - tmp24
692
+ tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
693
+ tmp39 = tmp36 == tmp38
694
+ tmp40 = tmp33 & tmp39
695
+ tmp41 = tmp31 | tmp40
696
+ tmp42 = tmp28 | tmp41
697
+ mask_mod_output = tmp42
698
+
699
+ # (grads) apply mask for fully masked block
700
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
701
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
702
+ if not PRESCALE_QK:
703
+ post_mod_scores *= RCP_LN2
704
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
705
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
706
+ # Compute dV.
707
+ ppT = pT
708
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
709
+ if IS_DIVISIBLE:
710
+ Di = tl.load(DELTA + offs_m1)
711
+ else:
712
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
713
+ # Compute dP and dS.
714
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
715
+ dsT = pT * (dpT - Di[None, :])
716
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
717
+ tmp43 = (dsT)
718
+ grad_scores = tmp43
719
+
720
+
721
+
722
+ if not IS_DIVISIBLE:
723
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
724
+
725
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
726
+ if not WRITE_DQ:
727
+ idx_b = off_z
728
+ idx_h = off_hq
729
+ idx_m = m
730
+ idx_n = n
731
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
732
+
733
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
734
+ dsT = grad_scores
735
+ if not IS_FULL_BLOCKS:
736
+ # (grads) apply mask for partially unmasked block
737
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
738
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
739
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
740
+
741
+ return dk, dv
742
+
743
+ # Utility triton funcs
744
+ @triton.jit
745
+ def get_offset_for_next_block(
746
+ loop_iter, col_indices, total_blocks,
747
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
748
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
749
+ ):
750
+ if BLOCKS_ARE_CONTIGUOUS:
751
+ return BLOCK
752
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
753
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
754
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
755
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
756
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
757
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
758
+ return offset
759
+
760
+ @triton.jit
761
+ def get_bounded_indices(indices, max_len=None):
762
+ return indices % max_len if max_len is not None else indices
763
+
764
+ @triton.jit
765
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
766
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
767
+ return tl.load(block_ptr)
768
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
769
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
770
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
771
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
772
+ else:
773
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
774
+
775
+ @triton.jit
776
+ def load_checked_2d(
777
+ ptr,
778
+ offs_m,
779
+ offs_n,
780
+ stride_m,
781
+ stride_n,
782
+ IS_DIVISIBLE_M: tl.constexpr,
783
+ IS_DIVISIBLE_N: tl.constexpr,
784
+ M_LEN: tl.constexpr,
785
+ N_LEN: tl.constexpr,
786
+ ):
787
+ # Calculate final pointer if strides are provided
788
+ if stride_m is not None and stride_n is not None:
789
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
790
+
791
+ # Handle all masking cases
792
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
793
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
794
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
795
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
796
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
797
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
798
+ else: # Both divisible
799
+ return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/6b/c6b364ingwlftc5camjox4wdd5z5l4famigf6ojv4cyji4ju37fy.py ADDED
@@ -0,0 +1,879 @@
1
+ # AOT ID: ['5_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/za/czavreibvu56vq4htyyxtbexpf3r3xsyhsclf2t4oq5z37c2h7e5.py
38
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
39
+ # Source node to ATen node mapping:
40
+ # flex_attention => flex_attention
41
+ # Graph fragment:
42
+ # %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1]
43
+ # %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1]
44
+ # %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1]
45
+ # %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
46
+ # %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
47
+ # %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg9_1]
48
+ # %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg6_1]
49
+ # %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg10_1]
50
+ # %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg11_1]
51
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
52
+ # return %buf2
53
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+
61
+ @triton_heuristics.template(
62
+
63
+ num_stages=3,
64
+ num_warps=2,
65
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
67
+
68
+ )
69
+ @triton.jit
70
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
71
+ PRESCALE_QK : tl.constexpr = False
72
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
73
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
74
+ WRITE_DQ : tl.constexpr = True
75
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
76
+ OUTPUT_MAX : tl.constexpr = False
77
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
78
+ IS_DIVISIBLE : tl.constexpr = False
79
+ GQA_SHARED_HEADS : tl.constexpr = 4
80
+ HAS_FULL_BLOCKS : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831845
82
+ SPLIT_KV : tl.constexpr = 32
83
+ QK_HEAD_DIM : tl.constexpr = 128
84
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
85
+ V_HEAD_DIM : tl.constexpr = 128
86
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
87
+ SAFE_HEAD_DIM : tl.constexpr = True
88
+ BLOCK_M : tl.constexpr = 512
89
+ SAFE_M_BOUNDARY : tl.constexpr = False
90
+ SAFE_N_BOUNDARY : tl.constexpr = True
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
93
+ USE_TMA : tl.constexpr = False
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ M = arg_M
99
+ L = arg_L
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ # Q: Query, K: Key, V: Value
107
+ # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
108
+ # M: Number of queries, N: Number of keys/values
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # BLOCK_M, QK_HEAD_DIM: M, and D dimension are always assigned to the same block
112
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
113
+ # (Modifiable) Config options:
114
+ # SPLIT_KV: number of blocks K & V are split into
115
+ # TILE_KV: length of each local KV split
116
+ # BLOCK_M: block size that Q is padded along seqlen dim.
117
+ # BLOCK_N: block size of K & V along N dimension.
118
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
119
+ #
120
+ # change of base out of the loop
121
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
122
+ # is not masked out? If so, we can skip an extra safety check
123
+ # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
124
+ # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
125
+
126
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
127
+ #
128
+ # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
129
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
130
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
131
+ #
132
+ #
133
+ # Output: ACC output accumulated across local KV split.
134
+
135
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
136
+
137
+ # Define Q Strides
138
+ stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
139
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
140
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
141
+ stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
142
+ stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
143
+
144
+
145
+ Z = 1
146
+ ZKV = 1
147
+ HKV = 8
148
+ G: tl.constexpr = GQA_SHARED_HEADS
149
+ HQ = HKV * G
150
+ Q_LEN = ks0
151
+ KV_LEN = ks1
152
+
153
+ MATMUL_PRECISION = Q.dtype.element_ty
154
+
155
+ # Make sure each split is a multiple of BLOCK_N
156
+ TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
157
+ TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
158
+ TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
159
+
160
+ off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
161
+ off_zkv = off_z % ZKV
162
+ off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
163
+ off_t = tl.program_id(1).to(INDEX_DTYPE)
164
+
165
+ q_offset = off_z * stride_qz + off_hkv * stride_qh
166
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
167
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
168
+
169
+ K = K + k_offset
170
+ V = V + v_offset
171
+
172
+ SPARSE_Z = 1
173
+ SPARSE_HQ = 1
174
+
175
+ sparse_idx_z = off_z % SPARSE_Z
176
+ sparse_idx_h = off_hkv % SPARSE_HQ
177
+
178
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
179
+ SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
180
+
181
+ # initialize pointer to m and l
182
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
183
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
184
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
185
+
186
+ # initialize offsets
187
+ tl.device_assert(BLOCK_M % G == 0)
188
+ BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
189
+ off_g = tl.arange(0, G) # [G]
190
+ offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
191
+ offs_hq = offs_g + off_hkv * G
192
+ off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
193
+ offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
194
+ offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
195
+ offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
196
+
197
+ # Get HZ offsets for KV_NUM_BLKS and KV_IDX
198
+ stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
199
+ sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
200
+ stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
201
+ sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
202
+
203
+ # Calculate KV blocks that belong to this CTA.
204
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
205
+ block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
206
+
207
+ q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
208
+
209
+ if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
210
+ q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
211
+ elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
212
+ q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
213
+ elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
214
+ q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
215
+ else:
216
+ q = tl.load(Q + q_offset + q_range)
217
+
218
+ q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
219
+
220
+
221
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
222
+ # find first kv block we are loading and the number of blocks we are loading
223
+ # Offset the kv_indices tensor by the correct batch and head
224
+ kv_indices = KV_IDX + sparse_idx_hz_offset
225
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
226
+ MAX_KV_IDX = 1
227
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
228
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
229
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
230
+ # first kv block we're loading
231
+
232
+ # last valid block according to sparse mask
233
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
234
+
235
+ offs_n = tl.arange(0, BLOCK_N) + off_n
236
+
237
+ desc_k = None
238
+ desc_v = None
239
+
240
+ acc, l_i, m_i = forward_inner(
241
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
242
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
243
+ # accumulated values
244
+ acc, l_i, m_i,
245
+ #offsets
246
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
247
+ off_n,
248
+ #block sparse data
249
+ kv_indices, kv_num_blocks,
250
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=False,
254
+ )
255
+
256
+
257
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
258
+ # We know these blocks are guaranteed to be "full", so we don't need to
259
+ # apply mask_mod to them - only score_mod
260
+ if HAS_FULL_BLOCKS:
261
+ kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
262
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
263
+ # Assign full block in a reverse order for off_t. Prioritize the last CTA.
264
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
265
+ block_n_end = block_n_start + TILE_KV_MULTIPLE
266
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
267
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
268
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
269
+
270
+ # last valid block according to sparse mask
271
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
272
+
273
+ offs_n = tl.arange(0, BLOCK_N) + off_n
274
+
275
+ acc, l_i, m_i = forward_inner(
276
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
277
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
278
+ # accumulated values
279
+ acc, l_i, m_i,
280
+ #offsets
281
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
282
+ off_n,
283
+ #block sparse data
284
+ kv_indices, kv_num_blocks,
285
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
286
+ MATMUL_PRECISION,
287
+ stride_kk, stride_kn, stride_vn, stride_vk,
288
+ IS_FULL_BLOCKS=True,
289
+ )
290
+
291
+ m_offset = off_t * stride_mt + off_z * stride_mz
292
+ l_offset = off_t * stride_lt + off_z * stride_lz
293
+
294
+ M_block_ptr = tl.make_block_ptr(
295
+ base=M + m_offset,
296
+ shape=(G, Q_LEN), # (G, M)
297
+ strides=(stride_mh, stride_mm),
298
+ offsets=(off_hkv*G, 0),
299
+ block_shape=(G, BLOCK_M_PER_HQ),
300
+ order=(1, 0)
301
+ )
302
+ L_block_ptr = tl.make_block_ptr(
303
+ base=L + l_offset,
304
+ shape=(G, Q_LEN), # (G, M)
305
+ strides=(stride_lh, stride_lm),
306
+ offsets=(off_hkv*G, 0),
307
+ block_shape=(G, BLOCK_M_PER_HQ),
308
+ order=(1, 0)
309
+ )
310
+
311
+ # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
312
+ m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
313
+ l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
314
+ if SAFE_M_BOUNDARY:
315
+ tl.store(M_block_ptr, m_i)
316
+ tl.store(L_block_ptr, l_i)
317
+ else:
318
+ tl.store(M_block_ptr, m_i, boundary_check=(1,))
319
+ tl.store(L_block_ptr, l_i, boundary_check=(1,))
320
+
321
+ # -- store output
322
+ idx_z = off_z
323
+ idx_t = off_t
324
+ idx_hq = off_hkv*G + off_g[:, None, None]
325
+ idx_m = off_m[None, :, None]
326
+ idx_d = offs_vd[None, None, :]
327
+
328
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
329
+ acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
330
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
331
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
332
+
333
+
334
+ # Utility triton funcs
335
+ @triton.jit
336
+ def get_offset_for_next_block(
337
+ loop_iter, col_indices, total_blocks,
338
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
339
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
340
+ ):
341
+ if BLOCKS_ARE_CONTIGUOUS:
342
+ return BLOCK
343
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
344
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
345
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
346
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
347
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
348
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
349
+ return offset
350
+
351
+ @triton.jit
352
+ def get_bounded_indices(indices, max_len=None):
353
+ return indices % max_len if max_len is not None else indices
354
+
355
+ @triton.jit
356
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
357
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
358
+ return tl.load(block_ptr)
359
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
360
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
361
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
362
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
363
+ else:
364
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
365
+
366
+ @triton.jit
367
+ def load_checked_2d(
368
+ ptr,
369
+ offs_m,
370
+ offs_n,
371
+ stride_m,
372
+ stride_n,
373
+ IS_DIVISIBLE_M: tl.constexpr,
374
+ IS_DIVISIBLE_N: tl.constexpr,
375
+ M_LEN: tl.constexpr,
376
+ N_LEN: tl.constexpr,
377
+ ):
378
+ # Calculate final pointer if strides are provided
379
+ if stride_m is not None and stride_n is not None:
380
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
381
+
382
+ # Handle all masking cases
383
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
384
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
385
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
386
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
387
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
388
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
389
+ else: # Both divisible
390
+ return tl.load(ptr)
391
+
392
+
393
+ # Common Imports
394
+ @triton.jit
395
+ def forward_block_mn(
396
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
397
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
398
+ # accumulated values
399
+ acc, l_i, m_i,
400
+ # Offsets
401
+ off_z, off_h, offs_m, offs_n,
402
+ # Offsets needed for TMA loads
403
+ kv_start,
404
+ kv_offset,
405
+ MATMUL_PRECISION, RCP_LN2,
406
+ # Strides for K and V
407
+ stride_kk, stride_kn, stride_vn, stride_vk,
408
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
409
+
410
+ ):
411
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
412
+ PRESCALE_QK : tl.constexpr = False
413
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
414
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
415
+ WRITE_DQ : tl.constexpr = True
416
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
417
+ OUTPUT_MAX : tl.constexpr = False
418
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
419
+ IS_DIVISIBLE : tl.constexpr = False
420
+ GQA_SHARED_HEADS : tl.constexpr = 4
421
+ HAS_FULL_BLOCKS : tl.constexpr = True
422
+ SM_SCALE : tl.constexpr = 0.08838834764831845
423
+ SPLIT_KV : tl.constexpr = 32
424
+ QK_HEAD_DIM : tl.constexpr = 128
425
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
426
+ V_HEAD_DIM : tl.constexpr = 128
427
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
428
+ SAFE_HEAD_DIM : tl.constexpr = True
429
+ BLOCK_M : tl.constexpr = 512
430
+ SAFE_M_BOUNDARY : tl.constexpr = False
431
+ SAFE_N_BOUNDARY : tl.constexpr = True
432
+ BLOCK_N : tl.constexpr = 64
433
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
434
+ USE_TMA : tl.constexpr = False
435
+ INDEX_DTYPE : tl.constexpr = tl.int32
436
+
437
+
438
+ # -- load k --
439
+ # NB reversed order since K is transposed
440
+ kv_base_offset = kv_start + kv_offset
441
+
442
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
443
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
444
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
445
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
446
+
447
+ k = tl.trans(k)
448
+ # -- compute qk ---
449
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
450
+ if not PRESCALE_QK:
451
+ qk *= SM_SCALE
452
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
453
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
454
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
455
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
456
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
457
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
458
+
459
+ tmp0 = (qk)
460
+ post_mod_scores = tmp0
461
+
462
+
463
+ if CHECK_BLOCK_BOUNDARY:
464
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
465
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
466
+
467
+ if not IS_FULL_BLOCKS:
468
+ tmp1 = (m)
469
+ tmp2 = tl.full([1], 0, tl.int32)
470
+ tmp3 = tmp1 < tmp2
471
+ tmp4 = (n)
472
+ tmp5 = tmp4 <= tmp1
473
+ tmp6 = tmp3 & tmp5
474
+ tmp7 = tmp1 >= tmp2
475
+ tmp8 = tmp4 < tmp2
476
+ tmp9 = tmp7 & tmp8
477
+ tmp10 = tmp8 == 0
478
+ tmp11 = tmp7 & tmp10
479
+ tmp12 = tmp1 - tmp2
480
+ tmp13 = tl.full([1], 16, tl.int32)
481
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
482
+ tmp15 = tmp4 - tmp2
483
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
484
+ tmp17 = tmp14 == tmp16
485
+ tmp18 = tmp11 & tmp17
486
+ tmp19 = tmp9 | tmp18
487
+ tmp20 = tmp6 | tmp19
488
+ mask_mod_output = tmp20
489
+
490
+
491
+ if CHECK_BLOCK_BOUNDARY:
492
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
493
+ # apply mask for partially unmasked blocks
494
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
495
+
496
+ if not PRESCALE_QK:
497
+ post_mod_scores *= RCP_LN2
498
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
499
+
500
+ # -- compute scaling constant ---
501
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
502
+ if not ROWS_GUARANTEED_SAFE:
503
+ masked_out_rows = (m_ij == float("-inf"))
504
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
505
+ else:
506
+ m_ij_masked = m_ij
507
+
508
+ alpha = tl.math.exp2(m_i - m_ij_masked)
509
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
510
+
511
+ # NB: l_i update is pulled up here since it's a bit faster
512
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
513
+ # m_ij
514
+ l_i = l_i * alpha + tl.sum(p, 1)
515
+ # # -- scale and update acc --
516
+ acc = acc * alpha[:, None]
517
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
518
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
519
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
520
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
521
+
522
+ # -- update m_i
523
+ m_i = m_ij
524
+
525
+ return acc, l_i, m_i
526
+
527
+ @triton.jit
528
+ def forward_inner(
529
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
530
+ q, K, V,
531
+ desc_k, desc_v, Q_LEN, KV_LEN,
532
+ # accumulated values
533
+ acc, l_i, m_i,
534
+ # Offsets used as inputs to score_mod & mask_mod
535
+ # of size [BLOCK_M, BLOCK_N] or scalar.
536
+ off_z, off_h, offs_m, offs_n,
537
+ # Offsets needed for TMA loads
538
+ kv_start,
539
+ # blocksparse data
540
+ kv_indices, kv_num_blocks,
541
+ # start kv and end kv block
542
+ block_n_start, block_n_end,
543
+ MATMUL_PRECISION,
544
+ # Strides for K and V
545
+ stride_kk, stride_kn, stride_vn, stride_vk,
546
+ IS_FULL_BLOCKS,
547
+ ):
548
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
549
+ PRESCALE_QK : tl.constexpr = False
550
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
551
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
552
+ WRITE_DQ : tl.constexpr = True
553
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
554
+ OUTPUT_MAX : tl.constexpr = False
555
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
556
+ IS_DIVISIBLE : tl.constexpr = False
557
+ GQA_SHARED_HEADS : tl.constexpr = 4
558
+ HAS_FULL_BLOCKS : tl.constexpr = True
559
+ SM_SCALE : tl.constexpr = 0.08838834764831845
560
+ SPLIT_KV : tl.constexpr = 32
561
+ QK_HEAD_DIM : tl.constexpr = 128
562
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
563
+ V_HEAD_DIM : tl.constexpr = 128
564
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
565
+ SAFE_HEAD_DIM : tl.constexpr = True
566
+ BLOCK_M : tl.constexpr = 512
567
+ SAFE_M_BOUNDARY : tl.constexpr = False
568
+ SAFE_N_BOUNDARY : tl.constexpr = True
569
+ BLOCK_N : tl.constexpr = 64
570
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
571
+ USE_TMA : tl.constexpr = False
572
+ INDEX_DTYPE : tl.constexpr = tl.int32
573
+
574
+
575
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
576
+ RCP_LN2: tl.constexpr = 1.44269504
577
+
578
+ if PRESCALE_QK:
579
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
580
+
581
+ kv_offset = 0
582
+
583
+ # loop over k, v and update accumulator until block_n_end
584
+ for start_n in range(block_n_start, block_n_end):
585
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
586
+ if IS_DIVISIBLE:
587
+ acc, l_i, m_i = forward_block_mn(
588
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
589
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
590
+ # accumulated values
591
+ acc, l_i, m_i,
592
+ # Offsets
593
+ off_z, off_h, offs_m, offs_n,
594
+ # Offsets needed for TMA loads
595
+ kv_start,
596
+ kv_offset,
597
+ MATMUL_PRECISION, RCP_LN2,
598
+ # Strides for K and V
599
+ stride_kk, stride_kn, stride_vn, stride_vk,
600
+ IS_FULL_BLOCKS,
601
+ )
602
+ else:
603
+ # Benchmark shows that even when we applied mod & mask to each block for non divisible seqlen,
604
+ # it's on par or slightly faster than only applying to the last block in fwd.
605
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
606
+ # to the last block because it's a lot faster.
607
+ acc, l_i, m_i = forward_block_mn(
608
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
609
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
610
+ # accumulated values
611
+ acc, l_i, m_i,
612
+ # Offsets
613
+ off_z, off_h, offs_m, offs_n,
614
+ # Offsets needed for TMA loads
615
+ kv_start,
616
+ kv_offset,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ # Strides for K and V
619
+ stride_kk, stride_kn, stride_vn, stride_vk,
620
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
621
+ )
622
+
623
+
624
+
625
+ offset = get_offset_for_next_block(
626
+ start_n, kv_indices, kv_num_blocks,
627
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
628
+ )
629
+
630
+ offs_n = offs_n + offset
631
+ kv_offset += offset
632
+
633
+
634
+ return acc, l_i, m_i
635
+ ''', device_str='cuda')
636
+
637
+
638
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7i/c7ijsdt7wst5xe64qslxvdevuoxlscozrh6zigwqavztuz3rptdj.py
639
+ # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
640
+ # Source node to ATen node mapping:
641
+ # flex_attention => flex_attention
642
+ # lse_scaled => mul_9
643
+ # Graph fragment:
644
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
645
+ # %buf4 : Tensor = PlaceHolder[target=buf4]
646
+ # %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
647
+ # %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
648
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
649
+ # %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
650
+ # return %buf5,%buf7,%mul_9
651
+ triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', '''
652
+ import triton
653
+ import triton.language as tl
654
+
655
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
656
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
657
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
658
+ triton_helpers.set_driver_to_gpu()
659
+
660
+ @triton_heuristics.persistent_reduction(
661
+ size_hints={'x': 4096, 'r0_': 32},
662
+ reduction_hint=ReductionHint.DEFAULT,
663
+ filename=__file__,
664
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
665
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
666
+ )
667
+ @triton.jit
668
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
669
+ r0_numel = 32
670
+ R0_BLOCK: tl.constexpr = 32
671
+ rnumel = r0_numel
672
+ RBLOCK: tl.constexpr = R0_BLOCK
673
+ xoffset = tl.program_id(0) * XBLOCK
674
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
675
+ xmask = xindex < xnumel
676
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
677
+ r0_offset = 0
678
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
679
+ roffset = r0_offset
680
+ rindex = r0_index
681
+ r0_1 = r0_index
682
+ x0 = xindex
683
+ x2 = (xindex % ks0)
684
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
685
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
686
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
687
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
688
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
689
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
690
+ tmp6 = float("-inf")
691
+ tmp7 = tmp4 == tmp6
692
+ tmp8 = tmp0 - tmp4
693
+ tmp9 = 0.0
694
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
695
+ tmp11 = libdevice.exp2(tmp10)
696
+ tmp12 = tmp5 * tmp11
697
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
698
+ tmp15 = tl.where(xmask, tmp13, 0)
699
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
700
+ tmp17 = 1.0
701
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
702
+ tmp19 = libdevice.log2(tmp18)
703
+ tmp20 = tmp19 + tmp4
704
+ tmp21 = 0.6931471805599453
705
+ tmp22 = tmp20 * tmp21
706
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
707
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
708
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
709
+ ''', device_str='cuda')
710
+
711
+
712
+ # kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xf/cxfq6ic3efg6yukl55sige4jtfyzy5dlspo6ipsr5mew7q7i32dg.py
713
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
714
+ # Source node to ATen node mapping:
715
+ # flex_attention => flex_attention, getitem
716
+ # Graph fragment:
717
+ # %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf2]
718
+ # %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
719
+ # %buf3 : Tensor = PlaceHolder[target=buf3]
720
+ # %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf8]
721
+ # %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
722
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
723
+ # %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {})
724
+ # return %buf8,%getitem
725
+ triton_per_fused_2 = async_compile.triton('triton_per_fused_2', '''
726
+ import triton
727
+ import triton.language as tl
728
+
729
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
730
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
731
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
732
+ triton_helpers.set_driver_to_gpu()
733
+
734
+ @triton_heuristics.persistent_reduction(
735
+ size_hints={'x': 524288, 'r0_': 32},
736
+ reduction_hint=ReductionHint.DEFAULT,
737
+ filename=__file__,
738
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
739
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
740
+ )
741
+ @triton.jit
742
+ def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
743
+ r0_numel = 32
744
+ R0_BLOCK: tl.constexpr = 32
745
+ rnumel = r0_numel
746
+ RBLOCK: tl.constexpr = R0_BLOCK
747
+ xoffset = tl.program_id(0) * XBLOCK
748
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
749
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
750
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
751
+ r0_offset = 0
752
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
753
+ roffset = r0_offset
754
+ rindex = r0_index
755
+ r0_2 = r0_index
756
+ x5 = xindex
757
+ x1 = xindex // 128
758
+ x0 = (xindex % 128)
759
+ x3 = ((xindex // 128) % ks0)
760
+ x4 = xindex // ks1
761
+ tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
762
+ tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
763
+ tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
764
+ tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
765
+ tmp2 = float("-inf")
766
+ tmp3 = tmp1 == tmp2
767
+ tmp5 = tmp4 - tmp1
768
+ tmp6 = 0.0
769
+ tmp7 = tl.where(tmp3, tmp6, tmp5)
770
+ tmp8 = libdevice.exp2(tmp7)
771
+ tmp9 = tmp0 * tmp8
772
+ tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
773
+ tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
774
+ tmp14 = 1.0
775
+ tmp15 = tl.where(tmp3, tmp14, tmp13)
776
+ tmp16 = (tmp12 / tmp15)
777
+ tmp17 = tmp16.to(tl.float32)
778
+ tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
779
+ ''', device_str='cuda')
780
+
781
+
782
+ async_compile.wait(globals())
783
+ del async_compile
784
+
785
+ class Runner:
786
+ def __init__(self, partitions):
787
+ self.partitions = partitions
788
+
789
+ def recursively_apply_fns(self, fns):
790
+ new_callables = []
791
+ for fn, c in zip(fns, self.partitions):
792
+ new_callables.append(fn(c))
793
+ self.partitions = new_callables
794
+
795
+ def call(self, args):
796
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
797
+ args.clear()
798
+ s50 = arg0_1
799
+ s0 = arg2_1
800
+ s43 = arg4_1
801
+ s37 = arg7_1
802
+ s71 = arg8_1
803
+ assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
804
+ assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
805
+ assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
806
+ assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
807
+ assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
808
+ assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
809
+ assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
810
+ assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
811
+ assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
812
+ assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
813
+ assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
814
+ with torch.cuda._DeviceGuard(7):
815
+ torch.cuda.set_device(7)
816
+ buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
817
+ buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
818
+ buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32)
819
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
820
+ stream7 = get_raw_stream(7)
821
+ triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream7)
822
+ del arg10_1
823
+ del arg11_1
824
+ del arg1_1
825
+ del arg3_1
826
+ del arg5_1
827
+ del arg6_1
828
+ del arg9_1
829
+ buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32)
830
+ buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
831
+ buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
832
+ # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
833
+ triton_per_fused_mul_1_xnumel = 32*s37
834
+ stream7 = get_raw_stream(7)
835
+ triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream7)
836
+ del buf1
837
+ ps0 = 128*s37
838
+ buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
839
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
840
+ triton_per_fused_2_xnumel = 4096*s37
841
+ stream7 = get_raw_stream(7)
842
+ triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream7)
843
+ del buf0
844
+ del buf2
845
+ del buf5
846
+ del buf7
847
+ return (buf9, buf10, )
848
+
849
+ runner = Runner(partitions=[])
850
+ call = runner.call
851
+ recursively_apply_fns = runner.recursively_apply_fns
852
+
853
+
854
+ def benchmark_compiled_module(times=10, repeat=10):
855
+ from torch._dynamo.testing import rand_strided
856
+ from torch._inductor.utils import print_performance
857
+ arg0_1 = 96
858
+ arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
859
+ arg2_1 = 96
860
+ arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
861
+ arg4_1 = 96
862
+ arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
863
+ arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
864
+ arg7_1 = 96
865
+ arg8_1 = 96
866
+ arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
867
+ arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
868
+ arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
869
+ arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
870
+ arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
871
+ arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
872
+ arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
873
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
874
+ return print_performance(fn, times=times, repeat=repeat)
875
+
876
+
877
+ if __name__ == "__main__":
878
+ from torch._inductor.wrapper_benchmark import compiled_module_main
879
+ compiled_module_main('None', benchmark_compiled_module)
progress/SpecForge/cache/compiled_kernels/6g/826e9651ad1f65a7d666097ac7518bb4b4d3dee1984132523b860dd02b66fff1.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"}
progress/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ x2 = (xindex % ks0)
34
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
35
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
36
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
39
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
40
+ tmp6 = float("-inf")
41
+ tmp7 = tmp4 == tmp6
42
+ tmp8 = tmp0 - tmp4
43
+ tmp9 = 0.0
44
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
45
+ tmp11 = libdevice.exp2(tmp10)
46
+ tmp12 = tmp5 * tmp11
47
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
48
+ tmp15 = tl.where(xmask, tmp13, 0)
49
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
50
+ tmp17 = 1.0
51
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
52
+ tmp19 = libdevice.log2(tmp18)
53
+ tmp20 = tmp19 + tmp4
54
+ tmp21 = 0.6931471805599453
55
+ tmp22 = tmp20 * tmp21
56
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
57
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
58
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
progress/SpecForge/cache/compiled_kernels/6o/c6ovzyfo6vkdwwzou6dtdvw7qjf65ifmzpcoltl2nx2xuluryjcy.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 128},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 128
20
+ R0_BLOCK: tl.constexpr = 128
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x0 = (xindex % ks0)
33
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
34
+ x3 = xindex
35
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
36
+ tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
37
+ tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
38
+ tmp2 = tmp0 * tmp1
39
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
40
+ tmp5 = tl.where(xmask, tmp3, 0)
41
+ tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
42
+ tmp7 = tmp6.to(tl.float32)
43
+ tmp9 = 0.6931471805599453
44
+ tmp10 = tmp8 * tmp9
45
+ tmp11 = 1.4426950408889634
46
+ tmp12 = tmp10 * tmp11
47
+ tmp13 = tmp7 - tmp12
48
+ tl.store(out_ptr1 + (x3), tmp13, xmask)
progress/SpecForge/cache/compiled_kernels/6o/df004f0eefe2693a59f2bae06581f78ab07b5ce2ec28936911ef13f1152e2ec9.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"}
progress/SpecForge/cache/compiled_kernels/6u/c6ulsdn73forgosxqs5bes2cerczsehypg7jodd4snit3gcqp6el.py ADDED
@@ -0,0 +1,27 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 32768},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 31232
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
25
+ tmp1 = 0.6931471805599453
26
+ tmp2 = tmp0 * tmp1
27
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 1
130
+ stride_kv_idx_h = 1
131
+ stride_kv_idx_m = 1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB: reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'ieee'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
504
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
505
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's a lot faster.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
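For readers following the template: `forward_inner`/`forward_block_mn` implement the standard streaming (online) softmax, kept in base 2 so that `exp2`/`log2` can be used throughout (`RCP_LN2` folds the change of base into the scores). Below is a minimal NumPy sketch of the same `(acc, l_i, m_i)` update for a single head, ignoring block sparsity, GQA, and the fully-masked-row fix-up; all names are illustrative:

```python
import numpy as np

def online_softmax_attention(q, k, v, sm_scale, block_n=64):
    """Streaming softmax over KV blocks, mirroring forward_inner's (acc, l_i, m_i) updates."""
    RCP_LN2 = 1.44269504
    m_i = np.full(q.shape[0], -np.inf)
    l_i = np.zeros(q.shape[0])
    acc = np.zeros((q.shape[0], v.shape[1]))
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        s = (q @ kb.T) * sm_scale * RCP_LN2          # scores in base-2 units
        m_ij = np.maximum(m_i, s.max(axis=1))        # new running row max
        p = np.exp2(s - m_ij[:, None])
        alpha = np.exp2(m_i - m_ij)                  # rescale factor for the old accumulator
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vb
        m_i = m_ij
    return acc / l_i[:, None], m_i + np.log2(l_i)    # attention output and base-2 logsumexp
```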
progress/SpecForge/cache/compiled_kernels/6u/e5329724392dcdd68d88f082f57e8929de539c9aa187c3a314edefdc595437d5.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA"}
progress/SpecForge/cache/compiled_kernels/7a/aee791ee3934869dfa55caee4270f116dc979737c2bbcce40af5d5394ccc9ac8.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/7a/c7a2brsshxp6zz4foe62t5ivwbd2dwr6ytjbhxp22vq2evdotx5z.py ADDED
@@ -0,0 +1,28 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4096},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x2 = xindex
23
+ x0 = (xindex % ks0)
24
+ x1 = triton_helpers.div_floor_integer(xindex, ks0)
25
+ tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
26
+ tmp1 = 0.6931471805599453
27
+ tmp2 = tmp0 * tmp1
28
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
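The store index in this variant uses the branch-free expression `(1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))`, which is simply Inductor's way of writing `max(1, ks0)` for a dynamic output stride. A tiny check of the equivalence:

```python
def inductor_stride(ks0: int) -> int:
    """The branch-free pattern emitted above: (1)*((1) >= ks0) + ks0*(ks0 > 1)."""
    return 1 * (1 >= ks0) + ks0 * (ks0 > 1)

assert all(inductor_stride(k) == max(1, k) for k in range(0, 8))
```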
progress/SpecForge/cache/compiled_kernels/7a/c7adkdqab5cvqxxnwnn5au23gorh2eg33cfxneh7bzb7untnuvpw.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB: reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
504
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
505
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's a lot faster.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
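The `tmp1..tmp20` sequence inside `forward_block_mn` is the inlined `mask_mod`. Transcribed back into plain Python (this is a reading of the generated ops, not the original SpecForge mask definition), it allows a query row `m` to see a key column `n` when:

```python
def mask_mod(m: int, n: int) -> bool:
    """Python transcription of tmp1..tmp20: negative offsets behave causally,
    non-negative offsets see all negative columns plus their own 16-wide block."""
    if m < 0:                      # tmp6: m < 0 and n <= m
        return n <= m
    if n < 0:                      # tmp9: m >= 0 and n < 0
        return True
    return m // 16 == n // 16      # tmp18: same 16-element block (both non-negative here)
```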
progress/SpecForge/cache/compiled_kernels/7i/c7ijsdt7wst5xe64qslxvdevuoxlscozrh6zigwqavztuz3rptdj.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ x2 = (xindex % ks0)
34
+ x3 = triton_helpers.div_floor_integer(xindex, ks0)
35
+ tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
36
+ tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ tmp3 = tl.where(xmask, tmp1, float("-inf"))
39
+ tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
40
+ tmp6 = float("-inf")
41
+ tmp7 = tmp4 == tmp6
42
+ tmp8 = tmp0 - tmp4
43
+ tmp9 = 0.0
44
+ tmp10 = tl.where(tmp7, tmp9, tmp8)
45
+ tmp11 = libdevice.exp2(tmp10)
46
+ tmp12 = tmp5 * tmp11
47
+ tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
48
+ tmp15 = tl.where(xmask, tmp13, 0)
49
+ tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
50
+ tmp17 = 1.0
51
+ tmp18 = tl.where(tmp7, tmp17, tmp16)
52
+ tmp19 = libdevice.log2(tmp18)
53
+ tmp20 = tmp19 + tmp4
54
+ tmp21 = 0.6931471805599453
55
+ tmp22 = tmp20 * tmp21
56
+ tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
57
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
58
+ tl.store(out_ptr1 + (x0), tmp16, xmask)
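`triton_per_fused_mul_1` reduces over 32 partial results per row (matching `SPLIT_KV: 32` in the decode template that follows): it takes the max of the per-split row maxima, rescales and sums the per-split `l` values, guards fully-masked rows, and converts the combined base-2 logsumexp to natural log. A NumPy sketch of the same combine step, under those assumptions:

```python
import numpy as np

def combine_split_lse(m_split, l_split):
    """Combine per-split (row max, sum of exp2) pairs into one natural-log LSE.
    m_split, l_split: arrays of shape (num_splits, num_rows)."""
    m = m_split.max(axis=0)                             # tmp4: global row max over splits
    delta = np.where(m == -np.inf, 0.0, m_split - m)    # tmp10: avoid (-inf) - (-inf)
    l = (l_split * np.exp2(delta)).sum(axis=0)          # tmp16: rescaled sum of partial sums
    l = np.where(m == -np.inf, 1.0, l)                  # tmp18: guard fully masked rows
    return (np.log2(l) + m) * 0.6931471805599453        # tmp22: base-2 LSE -> natural log
```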
progress/SpecForge/cache/compiled_kernels/7i/e1454135615b9b6420e5ef4fe0804f1b1346f398b5afb34dbe8085f0e900c8aa.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "RHE7JXFOJQBB2ESJBNT5OZ3PCTRJ7WSJRB7A2GLRM73N3EI7TWDQ"}
progress/SpecForge/cache/compiled_kernels/7n/c7n4jk5r4lsbq62vtrxzouvawlecnbfhy3owedw4ewuid7d56bjs.py ADDED
@@ -0,0 +1,582 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=2,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ GQA_SHARED_HEADS : tl.constexpr = 4
28
+ HAS_FULL_BLOCKS : tl.constexpr = True
29
+ SM_SCALE : tl.constexpr = 0.08838834764831845
30
+ SPLIT_KV : tl.constexpr = 32
31
+ QK_HEAD_DIM : tl.constexpr = 128
32
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
33
+ V_HEAD_DIM : tl.constexpr = 128
34
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
35
+ SAFE_HEAD_DIM : tl.constexpr = True
36
+ BLOCK_M : tl.constexpr = 512
37
+ SAFE_M_BOUNDARY : tl.constexpr = False
38
+ SAFE_N_BOUNDARY : tl.constexpr = True
39
+ BLOCK_N : tl.constexpr = 64
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ USE_TMA : tl.constexpr = False
42
+ INDEX_DTYPE : tl.constexpr = tl.int32
43
+ Q = arg_Q
44
+ K = arg_K
45
+ V = arg_V
46
+ M = arg_M
47
+ L = arg_L
48
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
49
+ KV_IDX = arg_KV_IDX
50
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
51
+ FULL_KV_IDX = arg_FULL_KV_IDX
52
+
53
+ # Sub notation for this kernel:
54
+ # Q: Query, K: Key, V: Value
55
+ # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
56
+ # M: Number of queries, N: Number of keys/values
57
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
58
+ # V_HEAD_DIM: The dimension of the value embeddings
59
+ # BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
60
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
61
+ # (Modifiable) Config options:
62
+ # SPLIT_KV: number of blocks K & V are split into
63
+ # TILE_KV: length of each local KV split
64
+ # BLOCK_M: block size that Q is padded along seqlen dim.
65
+ # BLOCK_N: block size of K & V along N dimension.
66
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
67
+ #
68
+ # change of base out of the loop
69
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
70
+ # is not masked out? If so, we can skip an extra safety check
71
+ # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
72
+ # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
73
+
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
75
+ #
76
+ # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
77
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
78
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
79
+ #
80
+ #
81
+ # Output: ACC output accumulated across local KV split.
82
+
83
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
84
+
85
+ # Define Q Strides
86
+ stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
87
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
88
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
89
+ stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
90
+ stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
91
+
92
+
93
+ Z = 1
94
+ ZKV = 1
95
+ HKV = 8
96
+ G: tl.constexpr = GQA_SHARED_HEADS
97
+ HQ = HKV * G
98
+ Q_LEN = ks0
99
+ KV_LEN = ks1
100
+
101
+ MATMUL_PRECISION = Q.dtype.element_ty
102
+
103
+ # Make sure each split is a multiple of BLOCK_N
104
+ TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
105
+ TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
106
+ TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
107
+
108
+ off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
109
+ off_zkv = off_z % ZKV
110
+ off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
111
+ off_t = tl.program_id(1).to(INDEX_DTYPE)
112
+
113
+ q_offset = off_z * stride_qz + off_hkv * stride_qh
114
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
115
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
116
+
117
+ K = K + k_offset
118
+ V = V + v_offset
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_z % SPARSE_Z
124
+ sparse_idx_h = off_hkv % SPARSE_HQ
125
+
126
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
127
+ SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
128
+
129
+ # initialize pointer to m and l
130
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
131
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
132
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
133
+
134
+ # initialize offsets
135
+ tl.device_assert(BLOCK_M % G == 0)
136
+ BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
137
+ off_g = tl.arange(0, G) # [G]
138
+ offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
139
+ offs_hq = offs_g + off_hkv * G
140
+ off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
141
+ offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
142
+ offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
143
+ offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
144
+
145
+ # Get HZ offsets for KV_NUM_BLKS and KV_IDX
146
+ stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
147
+ sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
148
+ stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
149
+ sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
150
+
151
+ # Calculate the KV blocks that belong to this CTA.
152
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
153
+ block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
154
+
155
+ q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
156
+
157
+ if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
158
+ q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
159
+ elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
160
+ q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
161
+ elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
162
+ q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
163
+ else:
164
+ q = tl.load(Q + q_offset + q_range)
165
+
166
+ q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
167
+
168
+
169
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
170
+ # find first kv block we are loading and the number of blocks we are loading
171
+ # Offset the kv_indices tensor by the correct batch and head
172
+ kv_indices = KV_IDX + sparse_idx_hz_offset
173
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
174
+ MAX_KV_IDX = 1
175
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
176
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
177
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
178
+ # first kv block we're loading
179
+
180
+ # last valid block according to sparse mask
181
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
182
+
183
+ offs_n = tl.arange(0, BLOCK_N) + off_n
184
+
185
+ desc_k = None
186
+ desc_v = None
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
190
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
191
+ # accumulated values
192
+ acc, l_i, m_i,
193
+ #offsets
194
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
195
+ off_n,
196
+ #block sparse data
197
+ kv_indices, kv_num_blocks,
198
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
199
+ MATMUL_PRECISION,
200
+ stride_kk, stride_kn, stride_vn, stride_vk,
201
+ IS_FULL_BLOCKS=False,
202
+ )
203
+
204
+
205
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
206
+ # We know these blocks are guaranteed to be "full", so we don't need to
207
+ # apply mask_mod to them - only score_mod
208
+ if HAS_FULL_BLOCKS:
209
+ kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
210
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
211
+ # Assign full block in a reverse order for off_t. Prioritize the last CTA.
212
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
213
+ block_n_end = block_n_start + TILE_KV_MULTIPLE
214
+ indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
215
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
216
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
217
+
218
+ # last valid block according to sparse mask
219
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
220
+
221
+ offs_n = tl.arange(0, BLOCK_N) + off_n
222
+
223
+ acc, l_i, m_i = forward_inner(
224
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
225
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
226
+ # accumulated values
227
+ acc, l_i, m_i,
228
+ #offsets
229
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
230
+ off_n,
231
+ #block sparse data
232
+ kv_indices, kv_num_blocks,
233
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
234
+ MATMUL_PRECISION,
235
+ stride_kk, stride_kn, stride_vn, stride_vk,
236
+ IS_FULL_BLOCKS=True,
237
+ )
238
+
239
+ m_offset = off_t * stride_mt + off_z * stride_mz
240
+ l_offset = off_t * stride_lt + off_z * stride_lz
241
+
242
+ M_block_ptr = tl.make_block_ptr(
243
+ base=M + m_offset,
244
+ shape=(G, Q_LEN), # (G, M)
245
+ strides=(stride_mh, stride_mm),
246
+ offsets=(off_hkv*G, 0),
247
+ block_shape=(G, BLOCK_M_PER_HQ),
248
+ order=(1, 0)
249
+ )
250
+ L_block_ptr = tl.make_block_ptr(
251
+ base=L + l_offset,
252
+ shape=(G, Q_LEN), # (G, M)
253
+ strides=(stride_lh, stride_lm),
254
+ offsets=(off_hkv*G, 0),
255
+ block_shape=(G, BLOCK_M_PER_HQ),
256
+ order=(1, 0)
257
+ )
258
+
259
+ # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
260
+ m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
261
+ l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
262
+ if SAFE_M_BOUNDARY:
263
+ tl.store(M_block_ptr, m_i)
264
+ tl.store(L_block_ptr, l_i)
265
+ else:
266
+ tl.store(M_block_ptr, m_i, boundary_check=(1,))
267
+ tl.store(L_block_ptr, l_i, boundary_check=(1,))
268
+
269
+ # -- store output
270
+ idx_z = off_z
271
+ idx_t = off_t
272
+ idx_hq = off_hkv*G + off_g[:, None, None]
273
+ idx_m = off_m[None, :, None]
274
+ idx_d = offs_vd[None, None, :]
275
+
276
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
277
+ acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
278
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
279
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
280
+
281
+
282
+ # Utility triton funcs
283
+ @triton.jit
284
+ def get_offset_for_next_block(
285
+ loop_iter, col_indices, total_blocks,
286
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
287
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
288
+ ):
289
+ if BLOCKS_ARE_CONTIGUOUS:
290
+ return BLOCK
291
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
292
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
293
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
294
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
295
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
296
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
297
+ return offset
298
+
299
+ @triton.jit
300
+ def get_bounded_indices(indices, max_len=None):
301
+ return indices % max_len if max_len is not None else indices
302
+
303
+ @triton.jit
304
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
305
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
306
+ return tl.load(block_ptr)
307
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
308
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
309
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
310
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
311
+ else:
312
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
313
+
314
+ @triton.jit
315
+ def load_checked_2d(
316
+ ptr,
317
+ offs_m,
318
+ offs_n,
319
+ stride_m,
320
+ stride_n,
321
+ IS_DIVISIBLE_M: tl.constexpr,
322
+ IS_DIVISIBLE_N: tl.constexpr,
323
+ M_LEN: tl.constexpr,
324
+ N_LEN: tl.constexpr,
325
+ ):
326
+ # Calculate final pointer if strides are provided
327
+ if stride_m is not None and stride_n is not None:
328
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
329
+
330
+ # Handle all masking cases
331
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
332
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
333
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
334
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
335
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
336
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
337
+ else: # Both divisible
338
+ return tl.load(ptr)
339
+
340
+
341
+ # Common Imports
342
+ @triton.jit
343
+ def forward_block_mn(
344
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
345
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
346
+ # accumulated values
347
+ acc, l_i, m_i,
348
+ # Offsets
349
+ off_z, off_h, offs_m, offs_n,
350
+ # Offsets needed for TMA loads
351
+ kv_start,
352
+ kv_offset,
353
+ MATMUL_PRECISION, RCP_LN2,
354
+ # Strides for K and V
355
+ stride_kk, stride_kn, stride_vn, stride_vk,
356
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
357
+
358
+ ):
359
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
360
+ PRESCALE_QK : tl.constexpr = False
361
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
362
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
363
+ WRITE_DQ : tl.constexpr = True
364
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
365
+ OUTPUT_MAX : tl.constexpr = False
366
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
367
+ IS_DIVISIBLE : tl.constexpr = False
368
+ GQA_SHARED_HEADS : tl.constexpr = 4
369
+ HAS_FULL_BLOCKS : tl.constexpr = True
370
+ SM_SCALE : tl.constexpr = 0.08838834764831845
371
+ SPLIT_KV : tl.constexpr = 32
372
+ QK_HEAD_DIM : tl.constexpr = 128
373
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
374
+ V_HEAD_DIM : tl.constexpr = 128
375
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
376
+ SAFE_HEAD_DIM : tl.constexpr = True
377
+ BLOCK_M : tl.constexpr = 512
378
+ SAFE_M_BOUNDARY : tl.constexpr = False
379
+ SAFE_N_BOUNDARY : tl.constexpr = True
380
+ BLOCK_N : tl.constexpr = 64
381
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
382
+ USE_TMA : tl.constexpr = False
383
+ INDEX_DTYPE : tl.constexpr = tl.int32
384
+
385
+
386
+ # -- load k --
387
+ # NB: reversed order since K is transposed
388
+ kv_base_offset = kv_start + kv_offset
389
+
390
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
391
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
392
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
393
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
394
+
395
+ k = tl.trans(k)
396
+ # -- compute qk ---
397
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
398
+ if not PRESCALE_QK:
399
+ qk *= SM_SCALE
400
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
401
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
402
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
403
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
404
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
405
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
406
+
407
+ tmp0 = (qk)
408
+ post_mod_scores = tmp0
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
413
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
414
+
415
+ if not IS_FULL_BLOCKS:
416
+ tmp1 = (m)
417
+ tmp2 = tl.full([1], 0, tl.int32)
418
+ tmp3 = tmp1 < tmp2
419
+ tmp4 = (n)
420
+ tmp5 = tmp4 <= tmp1
421
+ tmp6 = tmp3 & tmp5
422
+ tmp7 = tmp1 >= tmp2
423
+ tmp8 = tmp4 < tmp2
424
+ tmp9 = tmp7 & tmp8
425
+ tmp10 = tmp8 == 0
426
+ tmp11 = tmp7 & tmp10
427
+ tmp12 = tmp1 - tmp2
428
+ tmp13 = tl.full([1], 16, tl.int32)
429
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
430
+ tmp15 = tmp4 - tmp2
431
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
432
+ tmp17 = tmp14 == tmp16
433
+ tmp18 = tmp11 & tmp17
434
+ tmp19 = tmp9 | tmp18
435
+ tmp20 = tmp6 | tmp19
436
+ mask_mod_output = tmp20
437
+
438
+
439
+ if CHECK_BLOCK_BOUNDARY:
440
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
441
+ # apply mask for partially unmasked blocks
442
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
443
+
444
+ if not PRESCALE_QK:
445
+ post_mod_scores *= RCP_LN2
446
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
447
+
448
+ # -- compute scaling constant ---
449
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
450
+ if not ROWS_GUARANTEED_SAFE:
451
+ masked_out_rows = (m_ij == float("-inf"))
452
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
453
+ else:
454
+ m_ij_masked = m_ij
455
+
456
+ alpha = tl.math.exp2(m_i - m_ij_masked)
457
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
458
+
459
+ # NB: l_i update is pulled up here since it's a bit faster
460
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
461
+ # m_ij
462
+ l_i = l_i * alpha + tl.sum(p, 1)
463
+ # # -- scale and update acc --
464
+ acc = acc * alpha[:, None]
465
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
466
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
467
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
468
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
469
+
470
+ # -- update m_i
471
+ m_i = m_ij
472
+
473
+ return acc, l_i, m_i
474
+
475
+ @triton.jit
476
+ def forward_inner(
477
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
478
+ q, K, V,
479
+ desc_k, desc_v, Q_LEN, KV_LEN,
480
+ # accumulated values
481
+ acc, l_i, m_i,
482
+ # Offsets used as inputs to score_mod & mask_mod
483
+ # of size [BLOCK_M, BLOCK_N] or scalar.
484
+ off_z, off_h, offs_m, offs_n,
485
+ # Offsets needed for TMA loads
486
+ kv_start,
487
+ # blocksparse data
488
+ kv_indices, kv_num_blocks,
489
+ # start kv and end kv block
490
+ block_n_start, block_n_end,
491
+ MATMUL_PRECISION,
492
+ # Strides for K and V
493
+ stride_kk, stride_kn, stride_vn, stride_vk,
494
+ IS_FULL_BLOCKS,
495
+ ):
496
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
497
+ PRESCALE_QK : tl.constexpr = False
498
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
499
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
500
+ WRITE_DQ : tl.constexpr = True
501
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
502
+ OUTPUT_MAX : tl.constexpr = False
503
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
504
+ IS_DIVISIBLE : tl.constexpr = False
505
+ GQA_SHARED_HEADS : tl.constexpr = 4
506
+ HAS_FULL_BLOCKS : tl.constexpr = True
507
+ SM_SCALE : tl.constexpr = 0.08838834764831845
508
+ SPLIT_KV : tl.constexpr = 32
509
+ QK_HEAD_DIM : tl.constexpr = 128
510
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
511
+ V_HEAD_DIM : tl.constexpr = 128
512
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
513
+ SAFE_HEAD_DIM : tl.constexpr = True
514
+ BLOCK_M : tl.constexpr = 512
515
+ SAFE_M_BOUNDARY : tl.constexpr = False
516
+ SAFE_N_BOUNDARY : tl.constexpr = True
517
+ BLOCK_N : tl.constexpr = 64
518
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
519
+ USE_TMA : tl.constexpr = False
520
+ INDEX_DTYPE : tl.constexpr = tl.int32
521
+
522
+
523
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
524
+ RCP_LN2: tl.constexpr = 1.44269504
525
+
526
+ if PRESCALE_QK:
527
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
528
+
529
+ kv_offset = 0
530
+
531
+ # loop over k, v and update accumulator until block_n_end
532
+ for start_n in range(block_n_start, block_n_end):
533
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
534
+ if IS_DIVISIBLE:
535
+ acc, l_i, m_i = forward_block_mn(
536
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
537
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
538
+ # accumulated values
539
+ acc, l_i, m_i,
540
+ # Offsets
541
+ off_z, off_h, offs_m, offs_n,
542
+ # Offsets needed for TMA loads
543
+ kv_start,
544
+ kv_offset,
545
+ MATMUL_PRECISION, RCP_LN2,
546
+ # Strides for K and V
547
+ stride_kk, stride_kn, stride_vn, stride_vk,
548
+ IS_FULL_BLOCKS,
549
+ )
550
+ else:
551
+ # Benchmarks show that even when we apply mod & mask to each block for a non-divisible seqlen,
552
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
553
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
554
+ # to the last block because it's a lot faster.
555
+ acc, l_i, m_i = forward_block_mn(
556
+ arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
557
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
558
+ # accumulated values
559
+ acc, l_i, m_i,
560
+ # Offsets
561
+ off_z, off_h, offs_m, offs_n,
562
+ # Offsets needed for TMA loads
563
+ kv_start,
564
+ kv_offset,
565
+ MATMUL_PRECISION, RCP_LN2,
566
+ # Strides for K and V
567
+ stride_kk, stride_kn, stride_vn, stride_vk,
568
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
569
+ )
570
+
571
+
572
+
573
+ offset = get_offset_for_next_block(
574
+ start_n, kv_indices, kv_num_blocks,
575
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
576
+ )
577
+
578
+ offs_n = offs_n + offset
579
+ kv_offset += offset
580
+
581
+
582
+ return acc, l_i, m_i
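The loop in forward_inner applies the standard online-softmax (flash-attention style) update per KV block: the running accumulator is rescaled by alpha = exp2(m_old - m_new) before the block's contribution is added, with scores kept in the base-2 domain via RCP_LN2. A compact NumPy sketch of one forward_block_mn-style step, using illustrative names and shapes:

    import numpy as np

    def online_softmax_step(acc, l, m, scores, v):
        # acc: [M, D], l/m: [M], scores: [M, N] already multiplied by 1/ln(2), v: [N, D].
        m_new = np.maximum(m, scores.max(axis=1))        # new running row max
        m_safe = np.where(np.isneginf(m_new), 0.0, m_new)
        alpha = np.exp2(m - m_safe)                      # rescale the old accumulator and sum
        p = np.exp2(scores - m_safe[:, None])            # unnormalised probabilities
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v
        return acc, l, m_new

The final normalisation (acc / l, lse = m + log2(l)) happens once after the loop; in this decode template it is deferred to the cross-CTA reduction over M and L, such as the triton_per_fused_mul_1 kernel earlier in this commit.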
progress/SpecForge/cache/compiled_kernels/a3/ca3omjumwqpxxjrgphxuxva3yanssfkbnvrp3buqomyudb2eg4nc.py ADDED
@@ -0,0 +1,534 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831845
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
88
+
89
+ ZQ = 1
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 1
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 1
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 1
130
+ stride_kv_idx_h = 1
131
+ stride_kv_idx_m = 1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831845
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = (m)
371
+ tmp2 = tl.full([1], 0, tl.int32)
372
+ tmp3 = tmp1 < tmp2
373
+ tmp4 = (n)
374
+ tmp5 = tmp4 <= tmp1
375
+ tmp6 = tmp3 & tmp5
376
+ tmp7 = tmp1 >= tmp2
377
+ tmp8 = tmp4 < tmp2
378
+ tmp9 = tmp7 & tmp8
379
+ tmp10 = tmp8 == 0
380
+ tmp11 = tmp7 & tmp10
381
+ tmp12 = tmp1 - tmp2
382
+ tmp13 = tl.full([1], 16, tl.int32)
383
+ tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
384
+ tmp15 = tmp4 - tmp2
385
+ tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
386
+ tmp17 = tmp14 == tmp16
387
+ tmp18 = tmp11 & tmp17
388
+ tmp19 = tmp9 | tmp18
389
+ tmp20 = tmp6 | tmp19
390
+ mask_mod_output = tmp20
391
+
392
+
393
+ if CHECK_BLOCK_BOUNDARY:
394
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
395
+ # apply mask for partially unmasked blocks
396
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
397
+
398
+ if not PRESCALE_QK:
399
+ post_mod_scores *= RCP_LN2
400
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401
+
402
+ # -- compute scaling constant ---
403
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
404
+ if not ROWS_GUARANTEED_SAFE:
405
+ masked_out_rows = (m_ij == float("-inf"))
406
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
407
+ else:
408
+ m_ij_masked = m_ij
409
+
410
+ alpha = tl.math.exp2(m_i - m_ij_masked)
411
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
412
+
413
+ # NB: l_i update is pulled up here since it's a bit faster
414
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
415
+ # m_ij
416
+ l_i = l_i * alpha + tl.sum(p, 1)
417
+ # # -- scale and update acc --
418
+ acc = acc * alpha[:, None]
419
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
420
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
421
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
422
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
423
+
424
+ # -- update m_i
425
+ m_i = m_ij
426
+
427
+ return acc, l_i, m_i
428
+
429
+ @triton.jit
430
+ def forward_inner(
431
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
432
+ q, K, V,
433
+ desc_k, desc_v, Q_LEN, KV_LEN,
434
+ # accumulated values
435
+ acc, l_i, m_i,
436
+ # Offsets used as inputs to score_mod & mask_mod
437
+ # of size [BLOCK_M, BLOCK_N] or scalar.
438
+ off_z, off_h, offs_m, offs_n,
439
+ # Offsets needed for TMA loads
440
+ kv_start,
441
+ # blocksparse data
442
+ kv_indices, kv_num_blocks,
443
+ # start kv and end kv block
444
+ block_n_start, block_n_end,
445
+ MATMUL_PRECISION,
446
+ # Strides for K and V
447
+ stride_kk, stride_kn, stride_vn, stride_vk,
448
+ IS_FULL_BLOCKS,
449
+ ):
450
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
451
+ PRESCALE_QK : tl.constexpr = False
452
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
453
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
454
+ WRITE_DQ : tl.constexpr = True
455
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
456
+ OUTPUT_MAX : tl.constexpr = False
457
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
458
+ IS_DIVISIBLE : tl.constexpr = False
459
+ SM_SCALE : tl.constexpr = 0.08838834764831845
460
+ GQA_SHARED_HEADS : tl.constexpr = 4
461
+ HAS_FULL_BLOCKS : tl.constexpr = True
462
+ QK_HEAD_DIM : tl.constexpr = 128
463
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
464
+ V_HEAD_DIM : tl.constexpr = 128
465
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
466
+ SAFE_HEAD_DIM : tl.constexpr = True
467
+ USE_TMA : tl.constexpr = False
468
+ BLOCK_M : tl.constexpr = 128
469
+ BLOCK_N : tl.constexpr = 64
470
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
471
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
472
+ INDEX_DTYPE : tl.constexpr = tl.int32
473
+
474
+
475
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
476
+ RCP_LN2: tl.constexpr = 1.44269504
477
+
478
+ if PRESCALE_QK:
479
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
480
+
481
+ kv_offset = 0
482
+
483
+ # loop over k, v and update accumulator until block_n_end
484
+ for start_n in range(block_n_start, block_n_end):
485
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
486
+ if IS_DIVISIBLE:
487
+ acc, l_i, m_i = forward_block_mn(
488
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
489
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
490
+ # accumulated values
491
+ acc, l_i, m_i,
492
+ # Offsets
493
+ off_z, off_h, offs_m, offs_n,
494
+ # Offsets needed for TMA loads
495
+ kv_start,
496
+ kv_offset,
497
+ MATMUL_PRECISION, RCP_LN2,
498
+ # Strides for K and V
499
+ stride_kk, stride_kn, stride_vn, stride_vk,
500
+ IS_FULL_BLOCKS,
501
+ )
502
+ else:
503
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
504
+ # it's on par or slightly faster than only applying to the last block in fwd.
505
+ # However, we choose different strategy for bwd, where we only apply mod & mask
506
+ # to the last block because it's faster a lot.
507
+ acc, l_i, m_i = forward_block_mn(
508
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
509
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
510
+ # accumulated values
511
+ acc, l_i, m_i,
512
+ # Offsets
513
+ off_z, off_h, offs_m, offs_n,
514
+ # Offsets needed for TMA loads
515
+ kv_start,
516
+ kv_offset,
517
+ MATMUL_PRECISION, RCP_LN2,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
521
+ )
522
+
523
+
524
+
525
+ offset = get_offset_for_next_block(
526
+ start_n, kv_indices, kv_num_blocks,
527
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
528
+ )
529
+
530
+ offs_n = offs_n + offset
531
+ kv_offset += offset
532
+
533
+
534
+ return acc, l_i, m_i
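This second template is the single-pass flex-attention forward (one CTA per BLOCK_M of queries, no KV split); it is what torch.compile lowers PyTorch's flex_attention call to. As a hedged reconstruction of how such a kernel is driven from Python: the mask_mod below merely mirrors the predicate compiled into forward_block_mn above (for non-negative indices it appears to reduce to block-local attention with block size 16) and is not taken from the SpecForge sources:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def block_local_mask(b, h, q_idx, kv_idx):
        # Mirrors the compiled predicate: a query may attend to keys in its own block of 16.
        return (q_idx // 16) == (kv_idx // 16)

    if torch.cuda.is_available():
        B, HQ, HKV, L, D = 1, 32, 8, 256, 128  # HQ = GQA_SHARED_HEADS * HKV = 4 * 8, head dim 128 as above
        q = torch.randn(B, HQ, L, D, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(B, HKV, L, D, device="cuda", dtype=torch.bfloat16)
        v = torch.randn(B, HKV, L, D, device="cuda", dtype=torch.bfloat16)
        block_mask = create_block_mask(block_local_mask, B=None, H=None, Q_LEN=L, KV_LEN=L, device="cuda")
        out = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)  # default scale = 1/sqrt(128) = SM_SCALE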
progress/SpecForge/cache/compiled_kernels/ad/4a5ce6c582fc1ef37d4cf3003d603da533264d4c59e6e0cb171d0b7490f32260.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/ad/b7e210dbfa93430d766b2fdeddfba773a52faf0eca21a68632a8853e9c3ecaf4.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}