Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- progress/SpecForge/cache/compiled_kernels/22/2cc448de6eeb5f6db19c2adf4fa08d257ac050432532b15d4b1b5f447657bb74.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/22/c225p5q54jhc2rfoccuzlgejscvq2in5jzxlzcilu44cplhbfreo.py +28 -0
- progress/SpecForge/cache/compiled_kernels/22/c22yrjhqxirm4hkhfziohi6cfktos4modosbdflw25tw7a5d5gy7.py +799 -0
- progress/SpecForge/cache/compiled_kernels/2h/c2hwir33itr7umd7f5wx6cpaiwom2wrrbqi5cyznolnljriqz7pk.py +1028 -0
- progress/SpecForge/cache/compiled_kernels/2m/50bb8a1cf8ca03a72155f9f84fae7c10cfe881f9f7d97e1e3f94ec85776c3639.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/2m/c2mzrxons6jvrj3mv77db5xyv5m4z73mego5v77sl66o6wuh5dbk.py +28 -0
- progress/SpecForge/cache/compiled_kernels/2p/c2pjjgigeh2ro4r74dzlvlf7os5rhnmyche5rzwawor6zxb6rvk2.py +1018 -0
- progress/SpecForge/cache/compiled_kernels/2x/c2xecscuz5jhvznv7jn4k545b7kcexuko5lz3em6woeo7u2ftonz.py +58 -0
- progress/SpecForge/cache/compiled_kernels/2x/c2ximikyisa7xxnki36flzcsdr4ziwruq7ujf3zymsuxon5pqv57.py +707 -0
- progress/SpecForge/cache/compiled_kernels/2x/dde0479cca0d878e6e0800ec13f7c80962354e837542bfc5f11f7b49306d323e.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/3h/6ee97c795357f97e7127237e15db9bd5fb14510b837eeb5094115cfaa1802d32.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/3h/c3h3fb5vqykgr7s3powfrnsc5alooplbijdgjizqo3xq5psavrvz.py +28 -0
- progress/SpecForge/cache/compiled_kernels/3h/c3hro2ygwh2ixqhmbrrdsjq6biaehv6lm5cbeo6yhlo6ssqkwpha.py +799 -0
- progress/SpecForge/cache/compiled_kernels/3p/c3pdrhexk4rwol7f5l5vh7n543dj6piq6gw5k66g2p4vlyhopnop.py +58 -0
- progress/SpecForge/cache/compiled_kernels/3p/d4e91f4bc49d9cfc59a03caa3a2e04988f99c358762e8d23eed306dbbe3eae25.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/3s/c3spq2k2yeawxvgwl4dczrad6qwkidiiyxz5xwsucqivwlx625g7.py +534 -0
- progress/SpecForge/cache/compiled_kernels/3u/c3umapah7vcozhvfk5uovlssor7v533y4crphqgd677nuoizbpvj.py +799 -0
- progress/SpecForge/cache/compiled_kernels/42/c424arzgjg22xrcyl4orsbfthh3vxddttchjdd7yswdd5pdxdhtv.py +1019 -0
- progress/SpecForge/cache/compiled_kernels/44/3728d77fd47f8b1056ec8670d5b1bd262db03ae9994292fee6203d32e3d9cd03.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/44/c44m5klhlzg7nfvzfelnbb3hjh2jwzh2e5yyk3vtcvhyw6rbnjo6.py +54 -0
- progress/SpecForge/cache/compiled_kernels/4d/c4d7fh2egdfps7aogbncwlp3ihfwtff243bbobq7vrxj2m2grl64.py +51 -0
- progress/SpecForge/cache/compiled_kernels/4d/fd68b3c1a3fd19883dc58697393b6044e6217afda9ea11f84bd620545197dd6b.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/4h/c4h32peoig2erjdxibxrq3sbpm533ci3z57ntqjhdemzxp2rhysl.py +799 -0
- progress/SpecForge/cache/compiled_kernels/4k/c4korm4huj2wookuw6gikboxrsp3m5yt45c7fxucyujswm5fgb3u.py +534 -0
- progress/SpecForge/cache/compiled_kernels/4x/31f9d1ee4882fe2005f02592ea2d9f20a1835b42c5baefd7795e8640f97fdc16.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/4x/c4xjhgyzut6anhrjeinspoinohfxvyl6skr4gd3vfrscrvsevmya.py +28 -0
- progress/SpecForge/cache/compiled_kernels/5b/c5blvz5sxoj2veuexokuub2zm2pg4l2nqbbny4rr2jhsiiyw6njy.py +534 -0
- progress/SpecForge/cache/compiled_kernels/5g/c5g7nnbi3zupsx7kdee2ed6g2fgrtd2jxyggsjpckfg5p7rps4qm.py +534 -0
- progress/SpecForge/cache/compiled_kernels/5j/3cc65a0fdb544c73efb7240355b77da3f1ab394b46f272fa923c368e6cc63c34.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/5j/c5j7yk5hlaaxs42qwjlmoczwtoukaw2dio2o6p7qfekdy5upikyv.py +54 -0
- progress/SpecForge/cache/compiled_kernels/5w/c5wutjfcact264ykgcamj2asvz4eqe3ygz47upjgib2qw5rnnihu.py +1019 -0
- progress/SpecForge/cache/compiled_kernels/5z/c5zh2j5k5rlsr5zd4tfbvoplpwmbtizbrldjq2hw4nndmjztlcuy.py +879 -0
- progress/SpecForge/cache/compiled_kernels/66/c66vzxeqjq4tywx6ezsscs3u3rb6yxac26rzmkwrlzc3kmkcnhlf.py +799 -0
- progress/SpecForge/cache/compiled_kernels/6b/c6b364ingwlftc5camjox4wdd5z5l4famigf6ojv4cyji4ju37fy.py +879 -0
- progress/SpecForge/cache/compiled_kernels/6g/826e9651ad1f65a7d666097ac7518bb4b4d3dee1984132523b860dd02b66fff1.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py +58 -0
- progress/SpecForge/cache/compiled_kernels/6o/c6ovzyfo6vkdwwzou6dtdvw7qjf65ifmzpcoltl2nx2xuluryjcy.py +48 -0
- progress/SpecForge/cache/compiled_kernels/6o/df004f0eefe2693a59f2bae06581f78ab07b5ce2ec28936911ef13f1152e2ec9.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/6u/c6ulsdn73forgosxqs5bes2cerczsehypg7jodd4snit3gcqp6el.py +27 -0
- progress/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py +534 -0
- progress/SpecForge/cache/compiled_kernels/6u/e5329724392dcdd68d88f082f57e8929de539c9aa187c3a314edefdc595437d5.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/7a/aee791ee3934869dfa55caee4270f116dc979737c2bbcce40af5d5394ccc9ac8.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/7a/c7a2brsshxp6zz4foe62t5ivwbd2dwr6ytjbhxp22vq2evdotx5z.py +28 -0
- progress/SpecForge/cache/compiled_kernels/7a/c7adkdqab5cvqxxnwnn5au23gorh2eg33cfxneh7bzb7untnuvpw.py +534 -0
- progress/SpecForge/cache/compiled_kernels/7i/c7ijsdt7wst5xe64qslxvdevuoxlscozrh6zigwqavztuz3rptdj.py +58 -0
- progress/SpecForge/cache/compiled_kernels/7i/e1454135615b9b6420e5ef4fe0804f1b1346f398b5afb34dbe8085f0e900c8aa.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/7n/c7n4jk5r4lsbq62vtrxzouvawlecnbfhy3owedw4ewuid7d56bjs.py +582 -0
- progress/SpecForge/cache/compiled_kernels/a3/ca3omjumwqpxxjrgphxuxva3yanssfkbnvrp3buqomyudb2eg4nc.py +534 -0
- progress/SpecForge/cache/compiled_kernels/ad/4a5ce6c582fc1ef37d4cf3003d603da533264d4c59e6e0cb171d0b7490f32260.best_config +1 -0
- progress/SpecForge/cache/compiled_kernels/ad/b7e210dbfa93430d766b2fdeddfba773a52faf0eca21a68632a8853e9c3ecaf4.best_config +1 -0
progress/SpecForge/cache/compiled_kernels/22/2cc448de6eeb5f6db19c2adf4fa08d257ac050432532b15d4b1b5f447657bb74.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"}
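Note (added for context, not part of the commit): each .best_config file in this cache is a one-line JSON record of the launch configuration that Inductor's autotuner selected for the matching compiled Triton kernel. A minimal sketch of inspecting one such record with only the standard library; the path is copied from the entry above and the field names are taken from the JSON shown there:

import json
from pathlib import Path

# Any *.best_config file in this cache has the same one-line JSON shape.
cfg_path = Path("progress/SpecForge/cache/compiled_kernels/22/2cc448de6eeb5f6db19c2adf4fa08d257ac050432532b15d4b1b5f447657bb74.best_config")
cfg = json.loads(cfg_path.read_text())

# For the record above this prints: 256 4 1
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])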
progress/SpecForge/cache/compiled_kernels/22/c225p5q54jhc2rfoccuzlgejscvq2in5jzxlzcilu44cplhbfreo.py
ADDED
@@ -0,0 +1,28 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 32768},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
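For orientation (added note, not part of the generated file): the kernel above is an elementwise fusion that multiplies its input by the constant 0.6931471805599453, i.e. ln(2), and writes the result through a stride expression driven by the dynamic size ks0. A rough eager-mode sketch of the same arithmetic, ignoring the output-layout handling the kernel performs:

import math
import torch

def fused_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as triton_poi_fused_mul_1: out = x * ln(2).
    return x * math.log(2.0)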
progress/SpecForge/cache/compiled_kernels/22/c22yrjhqxirm4hkhfziohi6cfktos4modosbdflw25tw7a5d5gy7.py
ADDED
@@ -0,0 +1,799 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},

)
@triton.jit
def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 2555904, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 638976, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 638976, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 2555904, 79872, 128, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 2555904, 128, 4096, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 638976, 128, 1024, 1

    ZQ = 1
    HQ = 32
    HKV = 8
    Q_LEN = 624
    ZKV = 1
    KV_LEN = 624

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0).to(INDEX_DTYPE)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
    off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
    off_zkv = off_zq % ZKV # kv batch idx

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 5
        stride_kv_idx_h = 25
        stride_kv_idx_m = 5

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 5
        stride_q_idx_h = 25
        stride_q_idx_n = 5


        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )


            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
        xindex = index_k + 128*index_n + 79872*off_hkv + 638976*off_zq
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)

@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
    K, V, # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 624
    KV_LEN = 624

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq


@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order to since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20


        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order to since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp21 = (ds)
    grad_scores = tmp21


    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq


@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 624
    KV_LEN = 624

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv


@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    tmp22 = (qkT)
    post_mod_scores = tmp22



    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp23 = (m)
        tmp24 = tl.full([1], 0, tl.int32)
        tmp25 = tmp23 < tmp24
        tmp26 = (n)
        tmp27 = tmp26 <= tmp23
        tmp28 = tmp25 & tmp27
        tmp29 = tmp23 >= tmp24
        tmp30 = tmp26 < tmp24
        tmp31 = tmp29 & tmp30
        tmp32 = tmp30 == 0
        tmp33 = tmp29 & tmp32
        tmp34 = tmp23 - tmp24
        tmp35 = tl.full([1], 16, tl.int32)
        tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
        tmp37 = tmp26 - tmp24
        tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
        tmp39 = tmp36 == tmp38
        tmp40 = tmp33 & tmp39
        tmp41 = tmp31 | tmp40
        tmp42 = tmp28 | tmp41
        mask_mod_output = tmp42

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp43 = (dsT)
    grad_scores = tmp43



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv

# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else: # Both divisible
        return tl.load(ptr)
progress/SpecForge/cache/compiled_kernels/2h/c2hwir33itr7umd7f5wx6cpaiwom2wrrbqi5cyznolnljriqz7pk.py
ADDED
|
@@ -0,0 +1,1028 @@
# AOT ID: ['1_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/il/cilouthst7wssvb523k44wgbbccu5i25bgzvyf7k74iapy3kurts.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem]
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1]
# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0]
# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=tangents_2]
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
# return %buf0,%buf1
triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 65536, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
    tmp6 = tmp4.to(tl.float32)
    tmp8 = 0.6931471805599453
    tmp9 = tmp7 * tmp8
    tmp10 = 1.4426950408889634
    tmp11 = tmp9 * tmp10
    tmp12 = tmp6 - tmp11
    tl.store(out_ptr1 + (x3), tmp12, xmask)
''', device_str='cuda')

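This first reduction kernel produces the DELTA buffer that the flex-attention backward template further down consumes: for each query row it accumulates the dot product of the attention output with its incoming gradient over the 128-wide head dimension, then subtracts the logsumexp gradient rescaled by ln 2 (the 0.6931... factor from the %mul_19 node above) and log2 e (the kernel's own change of base). A plain-PyTorch sketch of the same computation, with illustrative tensor names and assuming contiguous layouts, might look like:

import torch

def delta_reference(out, grad_out, grad_lse):
    # out, grad_out: [B, H, M, D] attention output and its gradient (bf16 in the kernel)
    # grad_lse:      [B, H, M]    gradient flowing into the logsumexp output (fp32)
    ln2, log2e = 0.6931471805599453, 1.4426950408889634
    # Mathematically ln2 * log2e == 1, so this reduces to sum(out * grad_out) - grad_lse;
    # the kernel keeps both factors because they originate from two separate graph nodes.
    dot = (out.float() * grad_out.float()).sum(dim=-1)
    return dot - grad_lse * ln2 * log2e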
| 101 |
+
|
| 102 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/2d/c2dfd76xaqhs43mudrn3koh54vgu4ljjh5tphjwwwqhxcdloa7kl.py
|
| 103 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 104 |
+
# Source node to ATen node mapping:
|
| 105 |
+
# Graph fragment:
|
| 106 |
+
# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2]
|
| 107 |
+
# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4]
|
| 108 |
+
# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6]
|
| 109 |
+
# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1]
|
| 110 |
+
# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1]
|
| 111 |
+
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1]
|
| 112 |
+
# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3]
|
| 113 |
+
# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5]
|
| 114 |
+
# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:6" = PlaceHolder[target=primals_13]
|
| 115 |
+
# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:6" = PlaceHolder[target=primals_9]
|
| 116 |
+
# %primals_16 : Tensor "i32[1, 1, 13][13, 13, 1]cuda:6" = PlaceHolder[target=primals_16]
|
| 117 |
+
# %primals_17 : Tensor "i32[1, 1, 13, 13][169, 169, 13, 1]cuda:6" = PlaceHolder[target=primals_17]
|
| 118 |
+
# %primals_14 : Tensor "i32[1, 1, 13][13, 13, 1]cuda:6" = PlaceHolder[target=primals_14]
|
| 119 |
+
# %primals_15 : Tensor "i32[1, 1, 13, 13][169, 169, 13, 1]cuda:6" = PlaceHolder[target=primals_15]
|
| 120 |
+
# %primals_18 : Tensor "i32[1, 1, 13][13, 13, 1]cuda:6" = PlaceHolder[target=primals_18]
|
| 121 |
+
# %primals_19 : Tensor "i32[1, 1, 13, 13][169, 169, 13, 1]cuda:6" = PlaceHolder[target=primals_19]
|
| 122 |
+
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 123 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 124 |
+
# return %getitem_4
|
| 125 |
+
triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
|
| 126 |
+
import triton
|
| 127 |
+
import triton.language as tl
|
| 128 |
+
|
| 129 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 130 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 131 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 132 |
+
|
| 133 |
+
@triton_heuristics.template(
|
| 134 |
+
|
| 135 |
+
num_stages=3,
|
| 136 |
+
num_warps=8,
|
| 137 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 138 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 139 |
+
|
| 140 |
+
)
|
| 141 |
+
@triton.jit
|
| 142 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
|
| 143 |
+
PRESCALE_QK : tl.constexpr = False
|
| 144 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 145 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 146 |
+
WRITE_DQ : tl.constexpr = True
|
| 147 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 148 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 149 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 150 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 151 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 152 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 153 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 154 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 155 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 156 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 157 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 158 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 159 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 160 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 161 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 162 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 163 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 164 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 165 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 166 |
+
Q = arg_Q
|
| 167 |
+
K = arg_K
|
| 168 |
+
V = arg_V
|
| 169 |
+
LSE = arg_LSE
|
| 170 |
+
DELTA = arg_DELTA
|
| 171 |
+
DO = arg_DO
|
| 172 |
+
DQ = arg_DQ
|
| 173 |
+
DV = arg_DV
|
| 174 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 175 |
+
KV_IDX = arg_KV_IDX
|
| 176 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 177 |
+
Q_IDX = arg_Q_IDX
|
| 178 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 179 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 180 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 181 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 182 |
+
|
| 183 |
+
# Sub notation for this kernel:
|
| 184 |
+
#
|
| 185 |
+
# Q: Query, K: Key, V: Value
|
| 186 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 187 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 188 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 189 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 190 |
+
# inductor codegen
|
| 191 |
+
# M: Number of queries, N: Number of keys/values
|
| 192 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 193 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 194 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 195 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 196 |
+
# (Modifiable) Performance tuning options
|
| 197 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 198 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 199 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 200 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 201 |
+
#
|
| 202 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 203 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 204 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 205 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 206 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 207 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 208 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 209 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 210 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 211 |
+
|
| 212 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 213 |
+
# or involve a numerics vs. perf tradeoff
|
| 214 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 215 |
+
# about 20% more numerical error, but slightly faster.
|
| 216 |
+
|
| 217 |
+
# Define strides of inputs
|
| 218 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
|
| 219 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
|
| 220 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
|
| 221 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
|
| 222 |
+
|
| 223 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
|
| 224 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
|
| 225 |
+
|
| 226 |
+
ZQ = 1
|
| 227 |
+
HQ = 32
|
| 228 |
+
HKV = 8
|
| 229 |
+
Q_LEN = ks0
|
| 230 |
+
ZKV = 1
|
| 231 |
+
KV_LEN = ks1
|
| 232 |
+
|
| 233 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 234 |
+
|
| 235 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 236 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 237 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 238 |
+
|
| 239 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 240 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 241 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 242 |
+
|
| 243 |
+
SPARSE_Z = 1
|
| 244 |
+
SPARSE_HQ = 1
|
| 245 |
+
|
| 246 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 247 |
+
|
| 248 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 249 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 250 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 251 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 252 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 253 |
+
|
| 254 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 255 |
+
K += k_adj
|
| 256 |
+
V += v_adj
|
| 257 |
+
DV += dv_adj
|
| 258 |
+
|
| 259 |
+
RCP_LN2 = 1.44269504
|
| 260 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 261 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 262 |
+
|
| 263 |
+
if pid >= NUM_KV_BLOCKS:
|
| 264 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 265 |
+
# THIS BLOCK DOES DQ
|
| 266 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 267 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 268 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 269 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 270 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 271 |
+
stride_kv_num_blks_h = ks2
|
| 272 |
+
stride_kv_idx_h = ks3*ks4
|
| 273 |
+
stride_kv_idx_m = ks4
|
| 274 |
+
|
| 275 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 276 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 277 |
+
|
| 278 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 279 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 280 |
+
|
| 281 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 282 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 283 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 284 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 285 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 286 |
+
|
| 287 |
+
Q2 = Q + q_adj2
|
| 288 |
+
DO2 = DO + do_adj2
|
| 289 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 290 |
+
# if Q is broadcasted)
|
| 291 |
+
DQ2 = DQ + dq_adj2
|
| 292 |
+
LSE2 = LSE + off_chz2
|
| 293 |
+
DELTA2 = DELTA + off_chz2
|
| 294 |
+
|
| 295 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 296 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 297 |
+
|
| 298 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 299 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 300 |
+
|
| 301 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 302 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 303 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 304 |
+
|
| 305 |
+
if PRESCALE_QK:
|
| 306 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 307 |
+
|
| 308 |
+
if IS_DIVISIBLE:
|
| 309 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 310 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 311 |
+
else:
|
| 312 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 313 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 314 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 315 |
+
lse = lse[:, None]
|
| 316 |
+
|
| 317 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 318 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 319 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 320 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 321 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 322 |
+
|
| 323 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 324 |
+
dq = bwd_dq_inner(
|
| 325 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 326 |
+
K, V,
|
| 327 |
+
dq, q, do, Di, lse,
|
| 328 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 329 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 330 |
+
kv_indices, sparse_kv_num_blocks,
|
| 331 |
+
MATMUL_PRECISION,
|
| 332 |
+
IS_FULL_BLOCKS=False,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if HAS_FULL_BLOCKS:
|
| 336 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 337 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 338 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 339 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 340 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 341 |
+
|
| 342 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 343 |
+
dq = bwd_dq_inner(
|
| 344 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 345 |
+
K, V,
|
| 346 |
+
dq, q, do, Di, lse,
|
| 347 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 348 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 349 |
+
kv_indices, sparse_kv_num_blocks,
|
| 350 |
+
MATMUL_PRECISION,
|
| 351 |
+
IS_FULL_BLOCKS=True,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Write back dQ.
|
| 355 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 356 |
+
dq *= SM_SCALE
|
| 357 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 358 |
+
tl.store(dq_ptrs, dq)
|
| 359 |
+
else:
|
| 360 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 361 |
+
else:
|
| 362 |
+
# THIS BLOCK DOES DK & DV
|
| 363 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 364 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 365 |
+
|
| 366 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 367 |
+
|
| 368 |
+
stride_q_num_blks_h = 13
|
| 369 |
+
stride_q_idx_h = 169
|
| 370 |
+
stride_q_idx_n = 13
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 374 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 375 |
+
|
| 376 |
+
start_n1 = pid * BLOCK_N1
|
| 377 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 378 |
+
|
| 379 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 380 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 381 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 382 |
+
|
| 383 |
+
if PRESCALE_QK:
|
| 384 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 385 |
+
|
| 386 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 387 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 388 |
+
|
| 389 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 390 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 391 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 392 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 393 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 394 |
+
|
| 395 |
+
Q1 = Q + q_adj1
|
| 396 |
+
DO1 = DO + do_adj1
|
| 397 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 398 |
+
# if Q is broadcasted)
|
| 399 |
+
LSE1 = LSE + off_chz1
|
| 400 |
+
DELTA1 = DELTA + off_chz1
|
| 401 |
+
|
| 402 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 403 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 404 |
+
|
| 405 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 406 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 407 |
+
|
| 408 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 409 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 410 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 411 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 412 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 413 |
+
|
| 414 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 415 |
+
dk, dv = bwd_dkdv_inner(
|
| 416 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 417 |
+
Q1, DO1, DELTA1, LSE1,
|
| 418 |
+
dk, dv, k, v,
|
| 419 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 420 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 421 |
+
q_indices, sparse_q_num_blocks,
|
| 422 |
+
MATMUL_PRECISION,
|
| 423 |
+
IS_FULL_BLOCKS=False,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
if HAS_FULL_BLOCKS:
|
| 428 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 429 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 430 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 431 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 432 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 433 |
+
|
| 434 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 435 |
+
dk, dv = bwd_dkdv_inner(
|
| 436 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 437 |
+
Q1, DO1, DELTA1, LSE1,
|
| 438 |
+
dk, dv, k, v,
|
| 439 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 440 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 441 |
+
q_indices, sparse_q_num_blocks,
|
| 442 |
+
MATMUL_PRECISION,
|
| 443 |
+
IS_FULL_BLOCKS=True,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# Write back dV and dK.
|
| 447 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 448 |
+
|
| 449 |
+
index_n = offs_n1[:, None]
|
| 450 |
+
index_k = offs_k[None, :]
|
| 451 |
+
index_v = offs_v[None, :]
|
| 452 |
+
|
| 453 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 454 |
+
tl.store(dv_ptrs, dv)
|
| 455 |
+
else:
|
| 456 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 457 |
+
|
| 458 |
+
dk *= SM_SCALE
|
| 459 |
+
|
| 460 |
+
if SAFE_HEAD_DIM:
|
| 461 |
+
mask = index_n < KV_LEN
|
| 462 |
+
else:
|
| 463 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 464 |
+
|
| 465 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 466 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 467 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 468 |
+
xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
|
| 469 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
|
| 470 |
+
|
| 471 |
+
@triton.jit
|
| 472 |
+
def bwd_dq_inner(
|
| 473 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 474 |
+
K, V, # pointers
|
| 475 |
+
dq, q, do, Di, lse,
|
| 476 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 477 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 478 |
+
kv_indices, sparse_kv_num_blocks,
|
| 479 |
+
MATMUL_PRECISION,
|
| 480 |
+
IS_FULL_BLOCKS,
|
| 481 |
+
):
|
| 482 |
+
PRESCALE_QK : tl.constexpr = False
|
| 483 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 484 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 485 |
+
WRITE_DQ : tl.constexpr = True
|
| 486 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 487 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 488 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 489 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 490 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 491 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 492 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 493 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 494 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 495 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 496 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 497 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 498 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 499 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 500 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 501 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 502 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 503 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 504 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 505 |
+
|
| 506 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 507 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 508 |
+
Q_LEN = ks0
|
| 509 |
+
KV_LEN = ks1
|
| 510 |
+
|
| 511 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 512 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 513 |
+
|
| 514 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 515 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 516 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 517 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 518 |
+
|
| 519 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 520 |
+
|
| 521 |
+
for start_n in range(0, hi):
|
| 522 |
+
dq = bwd_dq_block_mn(
|
| 523 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 524 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 525 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 526 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 527 |
+
kv_indices, sparse_kv_num_blocks,
|
| 528 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 529 |
+
IS_FULL_BLOCKS,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
# Increment pointers.
|
| 533 |
+
offset = get_offset_for_next_block(
|
| 534 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 535 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
kT_ptrs += offset * stride_kn
|
| 539 |
+
vT_ptrs += offset * stride_vn
|
| 540 |
+
|
| 541 |
+
offs_n2 += offset
|
| 542 |
+
|
| 543 |
+
return dq
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@triton.jit
|
| 547 |
+
def bwd_dq_block_mn(
|
| 548 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 549 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 550 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 551 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 552 |
+
kv_indices, sparse_kv_num_blocks,
|
| 553 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 554 |
+
IS_FULL_BLOCKS,
|
| 555 |
+
):
|
| 556 |
+
PRESCALE_QK : tl.constexpr = False
|
| 557 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 558 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 559 |
+
WRITE_DQ : tl.constexpr = True
|
| 560 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 561 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 562 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 563 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 564 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 565 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 566 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 567 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 568 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 569 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 570 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 571 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 572 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 573 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 574 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 575 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 576 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 577 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 578 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# NB reversed order to since K is transposed
|
| 582 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 583 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 584 |
+
if not PRESCALE_QK:
|
| 585 |
+
qk *= SM_SCALE
|
| 586 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 587 |
+
pre_mod_scores = qk
|
| 588 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 589 |
+
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
|
| 590 |
+
# that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
|
| 591 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 592 |
+
|
| 593 |
+
tmp0 = (qk)
|
| 594 |
+
post_mod_scores = tmp0
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
if not IS_DIVISIBLE:
|
| 600 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 601 |
+
|
| 602 |
+
if not IS_FULL_BLOCKS:
|
| 603 |
+
tmp1 = (m)
|
| 604 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 605 |
+
tmp3 = tmp1 < tmp2
|
| 606 |
+
tmp4 = (n)
|
| 607 |
+
tmp5 = tmp4 <= tmp1
|
| 608 |
+
tmp6 = tmp3 & tmp5
|
| 609 |
+
tmp7 = tmp1 >= tmp2
|
| 610 |
+
tmp8 = tmp4 < tmp2
|
| 611 |
+
tmp9 = tmp7 & tmp8
|
| 612 |
+
tmp10 = tmp8 == 0
|
| 613 |
+
tmp11 = tmp7 & tmp10
|
| 614 |
+
tmp12 = tmp1 - tmp2
|
| 615 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 616 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 617 |
+
tmp15 = tmp4 - tmp2
|
| 618 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 619 |
+
tmp17 = tmp14 == tmp16
|
| 620 |
+
tmp18 = tmp11 & tmp17
|
| 621 |
+
tmp19 = tmp9 | tmp18
|
| 622 |
+
tmp20 = tmp6 | tmp19
|
| 623 |
+
mask_mod_output = tmp20
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
# apply mask for partial masked block
|
| 627 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 628 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 629 |
+
if not PRESCALE_QK:
|
| 630 |
+
post_mod_scores *= RCP_LN2
|
| 631 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 632 |
+
# Compute dP and dS.
|
| 633 |
+
# NB reversed order to since V is transposed
|
| 634 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 635 |
+
|
| 636 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 637 |
+
ds = p * (dp - Di[:, None])
|
| 638 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 639 |
+
tmp21 = (ds)
|
| 640 |
+
grad_scores = tmp21
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
if not IS_DIVISIBLE:
|
| 644 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 645 |
+
|
| 646 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 647 |
+
if WRITE_DQ:
|
| 648 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 649 |
+
|
| 650 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 651 |
+
ds = grad_scores
|
| 652 |
+
|
| 653 |
+
if not IS_FULL_BLOCKS:
|
| 654 |
+
# (grads) apply mask for partially unmasked block
|
| 655 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 656 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 657 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 658 |
+
# Compute dQ.
|
| 659 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 660 |
+
|
| 661 |
+
return dq
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
@triton.jit
|
| 665 |
+
def bwd_dkdv_inner(
|
| 666 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 667 |
+
Q, DO, DELTA, LSE, # pointers
|
| 668 |
+
dk, dv, k, v,
|
| 669 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 670 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 671 |
+
q_indices, sparse_q_num_blocks,
|
| 672 |
+
MATMUL_PRECISION,
|
| 673 |
+
IS_FULL_BLOCKS,
|
| 674 |
+
):
|
| 675 |
+
PRESCALE_QK : tl.constexpr = False
|
| 676 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 677 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 678 |
+
WRITE_DQ : tl.constexpr = True
|
| 679 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 680 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 681 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 682 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 683 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 684 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 685 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 686 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 687 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 688 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 689 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 690 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 691 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 692 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 693 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 694 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 695 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 696 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 697 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 698 |
+
|
| 699 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 700 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 701 |
+
Q_LEN = ks0
|
| 702 |
+
KV_LEN = ks1
|
| 703 |
+
|
| 704 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 705 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 706 |
+
|
| 707 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 708 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 709 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 710 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 711 |
+
|
| 712 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 713 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 714 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 715 |
+
|
| 716 |
+
for start_m in range(0, hi):
|
| 717 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 718 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 719 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 720 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 721 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 722 |
+
q_indices, sparse_q_num_blocks,
|
| 723 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 724 |
+
IS_FULL_BLOCKS,
|
| 725 |
+
)
|
| 726 |
+
# Increment pointers.
|
| 727 |
+
offset = get_offset_for_next_block(
|
| 728 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 729 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
qT_ptrs += offset * stride_qm
|
| 733 |
+
do_ptrs += offset * stride_dom
|
| 734 |
+
offs_m1 += offset
|
| 735 |
+
|
| 736 |
+
return dk, dv
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
@triton.jit
|
| 740 |
+
def bwd_dkdv_block_mn(
|
| 741 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 742 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 743 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 744 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 745 |
+
q_indices, sparse_q_num_blocks,
|
| 746 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 747 |
+
IS_FULL_BLOCKS,
|
| 748 |
+
):
|
| 749 |
+
PRESCALE_QK : tl.constexpr = False
|
| 750 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 751 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 752 |
+
WRITE_DQ : tl.constexpr = True
|
| 753 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 754 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 755 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 756 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 757 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 758 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 759 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 760 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 761 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 762 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 763 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 764 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 765 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 766 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 767 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 768 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 769 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 770 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 771 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
# NB reversed order since Q is transposed
|
| 775 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 776 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 777 |
+
if IS_DIVISIBLE:
|
| 778 |
+
lse = tl.load(LSE + offs_m1)
|
| 779 |
+
else:
|
| 780 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 781 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 782 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 783 |
+
if not PRESCALE_QK:
|
| 784 |
+
qkT *= SM_SCALE
|
| 785 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 786 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 787 |
+
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
|
| 788 |
+
# that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
|
| 789 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 790 |
+
|
| 791 |
+
pre_mod_scores = qkT
|
| 792 |
+
tmp22 = (qkT)
|
| 793 |
+
post_mod_scores = tmp22
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
if not IS_DIVISIBLE:
|
| 798 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 799 |
+
|
| 800 |
+
if not IS_FULL_BLOCKS:
|
| 801 |
+
tmp23 = (m)
|
| 802 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 803 |
+
tmp25 = tmp23 < tmp24
|
| 804 |
+
tmp26 = (n)
|
| 805 |
+
tmp27 = tmp26 <= tmp23
|
| 806 |
+
tmp28 = tmp25 & tmp27
|
| 807 |
+
tmp29 = tmp23 >= tmp24
|
| 808 |
+
tmp30 = tmp26 < tmp24
|
| 809 |
+
tmp31 = tmp29 & tmp30
|
| 810 |
+
tmp32 = tmp30 == 0
|
| 811 |
+
tmp33 = tmp29 & tmp32
|
| 812 |
+
tmp34 = tmp23 - tmp24
|
| 813 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 814 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 815 |
+
tmp37 = tmp26 - tmp24
|
| 816 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 817 |
+
tmp39 = tmp36 == tmp38
|
| 818 |
+
tmp40 = tmp33 & tmp39
|
| 819 |
+
tmp41 = tmp31 | tmp40
|
| 820 |
+
tmp42 = tmp28 | tmp41
|
| 821 |
+
mask_mod_output = tmp42
|
| 822 |
+
|
| 823 |
+
# (grads) apply mask for fully masked block
|
| 824 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 825 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 826 |
+
if not PRESCALE_QK:
|
| 827 |
+
post_mod_scores *= RCP_LN2
|
| 828 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 829 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 830 |
+
# Compute dV.
|
| 831 |
+
ppT = pT
|
| 832 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 833 |
+
if IS_DIVISIBLE:
|
| 834 |
+
Di = tl.load(DELTA + offs_m1)
|
| 835 |
+
else:
|
| 836 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 837 |
+
# Compute dP and dS.
|
| 838 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 839 |
+
dsT = pT * (dpT - Di[None, :])
|
| 840 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 841 |
+
tmp43 = (dsT)
|
| 842 |
+
grad_scores = tmp43
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
if not IS_DIVISIBLE:
|
| 847 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 848 |
+
|
| 849 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 850 |
+
if not WRITE_DQ:
|
| 851 |
+
idx_b = off_z
|
| 852 |
+
idx_h = off_hq
|
| 853 |
+
idx_m = m
|
| 854 |
+
idx_n = n
|
| 855 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 856 |
+
|
| 857 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 858 |
+
dsT = grad_scores
|
| 859 |
+
if not IS_FULL_BLOCKS:
|
| 860 |
+
# (grads) apply mask for partially unmasked block
|
| 861 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 862 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 863 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 864 |
+
|
| 865 |
+
return dk, dv
|
| 866 |
+
|
| 867 |
+
# Utility triton funcs
|
| 868 |
+
@triton.jit
|
| 869 |
+
def get_offset_for_next_block(
|
| 870 |
+
loop_iter, col_indices, total_blocks,
|
| 871 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 872 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 873 |
+
):
|
| 874 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 875 |
+
return BLOCK
|
| 876 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 877 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 878 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 879 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 880 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 881 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 882 |
+
return offset
|
| 883 |
+
|
| 884 |
+
@triton.jit
|
| 885 |
+
def get_bounded_indices(indices, max_len=None):
|
| 886 |
+
return indices % max_len if max_len is not None else indices
|
| 887 |
+
|
| 888 |
+
@triton.jit
|
| 889 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 890 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 891 |
+
return tl.load(block_ptr)
|
| 892 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 893 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 894 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 895 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 896 |
+
else:
|
| 897 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 898 |
+
|
| 899 |
+
@triton.jit
|
| 900 |
+
def load_checked_2d(
|
| 901 |
+
ptr,
|
| 902 |
+
offs_m,
|
| 903 |
+
offs_n,
|
| 904 |
+
stride_m,
|
| 905 |
+
stride_n,
|
| 906 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 907 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 908 |
+
M_LEN: tl.constexpr,
|
| 909 |
+
N_LEN: tl.constexpr,
|
| 910 |
+
):
|
| 911 |
+
# Calculate final pointer if strides are provided
|
| 912 |
+
if stride_m is not None and stride_n is not None:
|
| 913 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 914 |
+
|
| 915 |
+
# Handle all masking cases
|
| 916 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 917 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 918 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 919 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 920 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 921 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 922 |
+
else: # Both divisible
|
| 923 |
+
return tl.load(ptr)
|
| 924 |
+
''', device_str='cuda')
|

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args
        args.clear()
        s37 = primals_10
        s0 = primals_11
        s22 = primals_7
        s72 = primals_8
        s99 = primals_12
        assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
        assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1))
        assert_size_stride(primals_14, (1, 1, 13), (13, 13, 1))
        assert_size_stride(primals_15, (1, 1, 13, 13), (169, 169, 13, 1))
        assert_size_stride(primals_16, (1, 1, 13), (13, 13, 1))
        assert_size_stride(primals_17, (1, 1, 13, 13), (169, 169, 13, 1))
        assert_size_stride(primals_18, (1, 1, 13), (13, 13, 1))
        assert_size_stride(primals_19, (1, 1, 13, 13), (169, 169, 13, 1))
        assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
        assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
        assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
        with torch.cuda._DeviceGuard(6):
            torch.cuda.set_device(6)
            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
            triton_red_fused_mul_0_xnumel = 32*s37
            stream6 = get_raw_stream(6)
            triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream6)
            del getitem
            del tangents_2
            buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
            buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
            stream6 = get_raw_stream(6)
            triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream6)
            del buf1
            del getitem_1
            del primals_13
            del primals_14
            del primals_15
            del primals_16
            del primals_17
            del primals_18
            del primals_19
            del primals_2
            del primals_4
            del primals_6
            del primals_9
            del tangents_1
        return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, )

runner = Runner(partitions=[])
|
| 995 |
+
call = runner.call
|
| 996 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 1000 |
+
from torch._dynamo.testing import rand_strided
|
| 1001 |
+
from torch._inductor.utils import print_performance
|
| 1002 |
+
primals_10 = 1568
|
| 1003 |
+
primals_11 = 1568
|
| 1004 |
+
primals_7 = 13
|
| 1005 |
+
primals_8 = 13
|
| 1006 |
+
primals_12 = 13
|
| 1007 |
+
primals_2 = rand_strided((1, 32, 1568, 128), (6422528, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
|
| 1008 |
+
primals_4 = rand_strided((1, 8, 1568, 128), (1605632, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16)
|
| 1009 |
+
primals_6 = rand_strided((1, 8, 1568, 128), (1605632, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16)
|
| 1010 |
+
primals_9 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1011 |
+
primals_13 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1012 |
+
primals_14 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1013 |
+
primals_15 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1014 |
+
primals_16 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1015 |
+
primals_17 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1016 |
+
primals_18 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1017 |
+
primals_19 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:6', dtype=torch.int32)
|
| 1018 |
+
getitem = rand_strided((1, 32, 1568, 128), (6422528, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
|
| 1019 |
+
getitem_1 = rand_strided((1, 32, 1568), (50176, 1568, 1), device='cuda:6', dtype=torch.float32)
|
| 1020 |
+
tangents_1 = rand_strided((1, 32, 1568, 128), (6422528, 200704, 128, 1), device='cuda:6', dtype=torch.bfloat16)
|
| 1021 |
+
tangents_2 = rand_strided((1, 32, 1568), (50176, 1568, 1), device='cuda:6', dtype=torch.float32)
|
| 1022 |
+
fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2])
|
| 1023 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
if __name__ == "__main__":
|
| 1027 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 1028 |
+
compiled_module_main('None', benchmark_compiled_module)
|
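Note on the block-sparse walk: `get_offset_for_next_block` above decides how far the Q or K/V pointers advance on each inner-loop iteration, either staying inside the current sparse block or jumping to the next index in `col_indices`. A minimal plain-Python sketch of the same arithmetic on a toy index list (the function and variable names here are illustrative, not part of the generated kernel):

def next_block_offset(loop_iter, col_indices, total_blocks,
                      sparse_block, sparse_block_multiple, block,
                      blocks_are_contiguous=False):
    # mirrors the jump arithmetic of get_offset_for_next_block
    if blocks_are_contiguous:
        return block
    cur = loop_iter // sparse_block_multiple
    cur_block = col_indices[cur]
    next_block = col_indices[cur + 1] if cur + 1 < total_blocks else cur_block
    needs_jump = (loop_iter + 1) % sparse_block_multiple == 0
    jump = (next_block - cur_block) * sparse_block - (sparse_block_multiple - 1) * block
    return jump if needs_jump else block

# Two BLOCK=64 tiles per 128-wide sparse block, walking sparse indices [0, 3, 7]:
offsets = [next_block_offset(i, [0, 3, 7], 3, 128, 2, 64) for i in range(5)]
# -> [64, 320, 64, 448, 64]: stay inside a sparse block, then jump to the next one.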
progress/SpecForge/cache/compiled_kernels/2m/50bb8a1cf8ca03a72155f9f84fae7c10cfe881f9f7d97e1e3f94ec85776c3639.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"}
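The `.best_config` entries record the launch configuration the Inductor autotuner settled on for a kernel. As a rough, illustrative mapping only (this is not how Inductor reads the cache), the fields correspond to a Triton config along these lines:

import triton

best = {"XBLOCK": 256, "num_warps": 4, "num_stages": 1}
cfg = triton.Config({"XBLOCK": best["XBLOCK"]},
                    num_warps=best["num_warps"],
                    num_stages=best["num_stages"])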
progress/SpecForge/cache/compiled_kernels/2m/c2mzrxons6jvrj3mv77db5xyv5m4z73mego5v77sl66o6wuh5dbk.py
ADDED
@@ -0,0 +1,28 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 65536},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
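Mathematically, `triton_poi_fused_mul_1` is an elementwise multiply by ln(2) ≈ 0.6931471805599453; the only extra work is the store index, which guards against a size-1 `ks0` dimension. A short eager-mode sketch of the arithmetic it performs (illustrative only; the generated kernel additionally controls the output strides):

import math
import torch

def fused_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # matches tmp2 = tmp0 * 0.6931471805599453 in the kernel above
    return x * math.log(2.0)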
progress/SpecForge/cache/compiled_kernels/2p/c2pjjgigeh2ro4r74dzlvlf7os5rhnmyche5rzwawor6zxb6rvk2.py
ADDED
@@ -0,0 +1,1018 @@
# AOT ID: ['1_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/t3/ct3idxec76qtl3vshvvwnyu6rebkbok6z5sogtvlz6a62sxeduqp.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %getitem : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem]
# %tangents_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 219136, 128, 1]cuda:2" = PlaceHolder[target=tangents_1]
# %buf0 : Tensor "bf16[1, 32, 1712][55296, 1728, 1]cuda:2" = PlaceHolder[target=buf0]
# %tangents_2 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=tangents_2]
# %mul_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1712, 1712, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
# return %buf0,%buf1
triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 65536, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 657408, 'r0_': 28049408}}
)
@triton.jit
def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 54784
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 1712)
    x1 = xindex // 1712
    x3 = xindex
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last')
    tmp6 = tmp4.to(tl.float32)
    tmp8 = 0.6931471805599453
    tmp9 = tmp7 * tmp8
    tmp10 = 1.4426950408889634
    tmp11 = tmp9 * tmp10
    tmp12 = tmp6 - tmp11
    tl.store(out_ptr1 + (x3), tmp12, xmask)
''', device_str='cuda')
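For reference, `triton_red_fused_mul_0` reduces over the head dimension and folds in the logsumexp gradient before the main flex-attention backward template runs. A hedged eager-mode sketch of the per-row computation, with argument names inferred from the wrapper's call site (`getitem` as the attention output, `tangents_1` as its gradient, `tangents_2` as the logsumexp gradient), not taken from the generated code:

import torch

def delta_reference(out: torch.Tensor, grad_out: torch.Tensor,
                    grad_lse: torch.Tensor) -> torch.Tensor:
    # fp32 dot product of out and grad_out over the last (head) dimension
    dot = (out.float() * grad_out.float()).sum(dim=-1)
    # the kernel scales grad_lse by ln(2) and then by 1/ln(2); the two factors cancel
    return dot - grad_lse * (0.6931471805599453 * 1.4426950408889634)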
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xs/cxsqqwycgcgtkcz7s77tzmxqhofti6tljwwkkkb7agpib7z2otzr.py
|
| 104 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 105 |
+
# Source node to ATen node mapping:
|
| 106 |
+
# Graph fragment:
|
| 107 |
+
# %primals_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1]
|
| 108 |
+
# %primals_2 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_2]
|
| 109 |
+
# %primals_3 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_3]
|
| 110 |
+
# %getitem_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=getitem_1]
|
| 111 |
+
# %buf1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=buf1]
|
| 112 |
+
# %tangents_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 219136, 128, 1]cuda:2" = PlaceHolder[target=tangents_1]
|
| 113 |
+
# %getitem_3 : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3]
|
| 114 |
+
# %getitem_5 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=getitem_5]
|
| 115 |
+
# %primals_5 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_5]
|
| 116 |
+
# %primals_4 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_4]
|
| 117 |
+
# %primals_8 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_8]
|
| 118 |
+
# %primals_9 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_9]
|
| 119 |
+
# %primals_6 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_6]
|
| 120 |
+
# %primals_7 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_7]
|
| 121 |
+
# %primals_10 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_10]
|
| 122 |
+
# %primals_11 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_11]
|
| 123 |
+
# %mul_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 124 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1712, 1712, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 125 |
+
# return %getitem_4
|
| 126 |
+
triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
|
| 127 |
+
import triton
|
| 128 |
+
import triton.language as tl
|
| 129 |
+
|
| 130 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 131 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 132 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 133 |
+
|
| 134 |
+
@triton_heuristics.template(
|
| 135 |
+
|
| 136 |
+
num_stages=3,
|
| 137 |
+
num_warps=8,
|
| 138 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 139 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 140 |
+
|
| 141 |
+
)
|
| 142 |
+
@triton.jit
|
| 143 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
|
| 144 |
+
PRESCALE_QK : tl.constexpr = False
|
| 145 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 146 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 147 |
+
WRITE_DQ : tl.constexpr = True
|
| 148 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 149 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 150 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 151 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 152 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 153 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 154 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 155 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 156 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 157 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 158 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 159 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 160 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 161 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 162 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 163 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 164 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 165 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 166 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 167 |
+
Q = arg_Q
|
| 168 |
+
K = arg_K
|
| 169 |
+
V = arg_V
|
| 170 |
+
LSE = arg_LSE
|
| 171 |
+
DELTA = arg_DELTA
|
| 172 |
+
DO = arg_DO
|
| 173 |
+
DQ = arg_DQ
|
| 174 |
+
DV = arg_DV
|
| 175 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 176 |
+
KV_IDX = arg_KV_IDX
|
| 177 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 178 |
+
Q_IDX = arg_Q_IDX
|
| 179 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 180 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 181 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 182 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 183 |
+
|
| 184 |
+
# Sub notation for this kernel:
|
| 185 |
+
#
|
| 186 |
+
# Q: Query, K: Key, V: Value
|
| 187 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 188 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 189 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 190 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 191 |
+
# inductor codegen
|
| 192 |
+
# M: Number of queries, N: Number of keys/values
|
| 193 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 194 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 195 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 196 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 197 |
+
# (Modifiable) Performance tuning options
|
| 198 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 199 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 200 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 201 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 202 |
+
#
|
| 203 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 204 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 205 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 206 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 207 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 208 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 209 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 210 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 211 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 212 |
+
|
| 213 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 214 |
+
# or involve a numerics vs. perf tradeoff
|
| 215 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 216 |
+
# about 20% more numerical error, but slightly faster.
|
| 217 |
+
|
| 218 |
+
# Define strides of inputs
|
| 219 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 7012352, 128, 4096, 1
|
| 220 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1753088, 128, 1024, 1
|
| 221 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1753088, 128, 1024, 1
|
| 222 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 7012352, 219136, 128, 1
|
| 223 |
+
|
| 224 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 7012352, 128, 4096, 1
|
| 225 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1753088, 128, 1024, 1
|
| 226 |
+
|
| 227 |
+
ZQ = 1
|
| 228 |
+
HQ = 32
|
| 229 |
+
HKV = 8
|
| 230 |
+
Q_LEN = 1712
|
| 231 |
+
ZKV = 1
|
| 232 |
+
KV_LEN = 1712
|
| 233 |
+
|
| 234 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 235 |
+
|
| 236 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 237 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 238 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 239 |
+
|
| 240 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 241 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 242 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 243 |
+
|
| 244 |
+
SPARSE_Z = 1
|
| 245 |
+
SPARSE_HQ = 1
|
| 246 |
+
|
| 247 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 248 |
+
|
| 249 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 250 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 251 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 252 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 253 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 254 |
+
|
| 255 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 256 |
+
K += k_adj
|
| 257 |
+
V += v_adj
|
| 258 |
+
DV += dv_adj
|
| 259 |
+
|
| 260 |
+
RCP_LN2 = 1.44269504
|
| 261 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 262 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 263 |
+
|
| 264 |
+
if pid >= NUM_KV_BLOCKS:
|
| 265 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 266 |
+
# THIS BLOCK DOES DQ
|
| 267 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 268 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 269 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 270 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 271 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 272 |
+
stride_kv_num_blks_h = 14
|
| 273 |
+
stride_kv_idx_h = 196
|
| 274 |
+
stride_kv_idx_m = 14
|
| 275 |
+
|
| 276 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 277 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 278 |
+
|
| 279 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 280 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 281 |
+
|
| 282 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 283 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 284 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 285 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 286 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 287 |
+
|
| 288 |
+
Q2 = Q + q_adj2
|
| 289 |
+
DO2 = DO + do_adj2
|
| 290 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 291 |
+
# if Q is broadcasted)
|
| 292 |
+
DQ2 = DQ + dq_adj2
|
| 293 |
+
LSE2 = LSE + off_chz2
|
| 294 |
+
DELTA2 = DELTA + off_chz2
|
| 295 |
+
|
| 296 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 297 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 298 |
+
|
| 299 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 300 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 301 |
+
|
| 302 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 303 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 304 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 305 |
+
|
| 306 |
+
if PRESCALE_QK:
|
| 307 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 308 |
+
|
| 309 |
+
if IS_DIVISIBLE:
|
| 310 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 311 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 312 |
+
else:
|
| 313 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 314 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 315 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 316 |
+
lse = lse[:, None]
|
| 317 |
+
|
| 318 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 319 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 320 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 321 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 322 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 323 |
+
|
| 324 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 325 |
+
dq = bwd_dq_inner(
|
| 326 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 327 |
+
K, V,
|
| 328 |
+
dq, q, do, Di, lse,
|
| 329 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 330 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 331 |
+
kv_indices, sparse_kv_num_blocks,
|
| 332 |
+
MATMUL_PRECISION,
|
| 333 |
+
IS_FULL_BLOCKS=False,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
if HAS_FULL_BLOCKS:
|
| 337 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 338 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 339 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 340 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 341 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 342 |
+
|
| 343 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 344 |
+
dq = bwd_dq_inner(
|
| 345 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 346 |
+
K, V,
|
| 347 |
+
dq, q, do, Di, lse,
|
| 348 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 349 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 350 |
+
kv_indices, sparse_kv_num_blocks,
|
| 351 |
+
MATMUL_PRECISION,
|
| 352 |
+
IS_FULL_BLOCKS=True,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Write back dQ.
|
| 356 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 357 |
+
dq *= SM_SCALE
|
| 358 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 359 |
+
tl.store(dq_ptrs, dq)
|
| 360 |
+
else:
|
| 361 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 362 |
+
else:
|
| 363 |
+
# THIS BLOCK DOES DK & DV
|
| 364 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 365 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 366 |
+
|
| 367 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 368 |
+
|
| 369 |
+
stride_q_num_blks_h = 14
|
| 370 |
+
stride_q_idx_h = 196
|
| 371 |
+
stride_q_idx_n = 14
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 375 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 376 |
+
|
| 377 |
+
start_n1 = pid * BLOCK_N1
|
| 378 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 379 |
+
|
| 380 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 381 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 382 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 383 |
+
|
| 384 |
+
if PRESCALE_QK:
|
| 385 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 386 |
+
|
| 387 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 388 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 389 |
+
|
| 390 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 391 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 392 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 393 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 394 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 395 |
+
|
| 396 |
+
Q1 = Q + q_adj1
|
| 397 |
+
DO1 = DO + do_adj1
|
| 398 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 399 |
+
# if Q is broadcasted)
|
| 400 |
+
LSE1 = LSE + off_chz1
|
| 401 |
+
DELTA1 = DELTA + off_chz1
|
| 402 |
+
|
| 403 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 404 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 405 |
+
|
| 406 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 407 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 408 |
+
|
| 409 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 410 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 411 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 412 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 413 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 414 |
+
|
| 415 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 416 |
+
dk, dv = bwd_dkdv_inner(
|
| 417 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 418 |
+
Q1, DO1, DELTA1, LSE1,
|
| 419 |
+
dk, dv, k, v,
|
| 420 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 421 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 422 |
+
q_indices, sparse_q_num_blocks,
|
| 423 |
+
MATMUL_PRECISION,
|
| 424 |
+
IS_FULL_BLOCKS=False,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
if HAS_FULL_BLOCKS:
|
| 429 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 430 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 431 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 432 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 433 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 434 |
+
|
| 435 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 436 |
+
dk, dv = bwd_dkdv_inner(
|
| 437 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 438 |
+
Q1, DO1, DELTA1, LSE1,
|
| 439 |
+
dk, dv, k, v,
|
| 440 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 441 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 442 |
+
q_indices, sparse_q_num_blocks,
|
| 443 |
+
MATMUL_PRECISION,
|
| 444 |
+
IS_FULL_BLOCKS=True,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Write back dV and dK.
|
| 448 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 449 |
+
|
| 450 |
+
index_n = offs_n1[:, None]
|
| 451 |
+
index_k = offs_k[None, :]
|
| 452 |
+
index_v = offs_v[None, :]
|
| 453 |
+
|
| 454 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 455 |
+
tl.store(dv_ptrs, dv)
|
| 456 |
+
else:
|
| 457 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 458 |
+
|
| 459 |
+
dk *= SM_SCALE
|
| 460 |
+
|
| 461 |
+
if SAFE_HEAD_DIM:
|
| 462 |
+
mask = index_n < KV_LEN
|
| 463 |
+
else:
|
| 464 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 465 |
+
|
| 466 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 467 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 468 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 469 |
+
xindex = index_k + 128*index_n + 219136*off_hkv + 1753088*off_zq
|
| 470 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
|
| 471 |
+
|
| 472 |
+
@triton.jit
|
| 473 |
+
def bwd_dq_inner(
|
| 474 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 475 |
+
K, V, # pointers
|
| 476 |
+
dq, q, do, Di, lse,
|
| 477 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 478 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 479 |
+
kv_indices, sparse_kv_num_blocks,
|
| 480 |
+
MATMUL_PRECISION,
|
| 481 |
+
IS_FULL_BLOCKS,
|
| 482 |
+
):
|
| 483 |
+
PRESCALE_QK : tl.constexpr = False
|
| 484 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 485 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 486 |
+
WRITE_DQ : tl.constexpr = True
|
| 487 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 488 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 489 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 490 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 491 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 492 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 493 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 494 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 495 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 496 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 497 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 498 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 499 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 500 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 501 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 502 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 503 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 504 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 505 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 506 |
+
|
| 507 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 508 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 509 |
+
Q_LEN = 1712
|
| 510 |
+
KV_LEN = 1712
|
| 511 |
+
|
| 512 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 513 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 514 |
+
|
| 515 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 516 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 517 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 518 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 519 |
+
|
| 520 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 521 |
+
|
| 522 |
+
for start_n in range(0, hi):
|
| 523 |
+
dq = bwd_dq_block_mn(
|
| 524 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 525 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 526 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 527 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 528 |
+
kv_indices, sparse_kv_num_blocks,
|
| 529 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 530 |
+
IS_FULL_BLOCKS,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
# Increment pointers.
|
| 534 |
+
offset = get_offset_for_next_block(
|
| 535 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 536 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
kT_ptrs += offset * stride_kn
|
| 540 |
+
vT_ptrs += offset * stride_vn
|
| 541 |
+
|
| 542 |
+
offs_n2 += offset
|
| 543 |
+
|
| 544 |
+
return dq
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
@triton.jit
|
| 548 |
+
def bwd_dq_block_mn(
|
| 549 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 550 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 551 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 552 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 553 |
+
kv_indices, sparse_kv_num_blocks,
|
| 554 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 555 |
+
IS_FULL_BLOCKS,
|
| 556 |
+
):
|
| 557 |
+
PRESCALE_QK : tl.constexpr = False
|
| 558 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 559 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 560 |
+
WRITE_DQ : tl.constexpr = True
|
| 561 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 562 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 563 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 564 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 565 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 566 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 567 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 568 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 569 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 570 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 571 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 572 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 573 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 574 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 575 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 576 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 577 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 578 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 579 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
# NB reversed order to since K is transposed
|
| 583 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 584 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 585 |
+
if not PRESCALE_QK:
|
| 586 |
+
qk *= SM_SCALE
|
| 587 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 588 |
+
pre_mod_scores = qk
|
| 589 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 590 |
+
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
|
| 591 |
+
# that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
|
| 592 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 593 |
+
|
| 594 |
+
tmp0 = (qk)
|
| 595 |
+
post_mod_scores = tmp0
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
if not IS_DIVISIBLE:
|
| 601 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 602 |
+
|
| 603 |
+
if not IS_FULL_BLOCKS:
|
| 604 |
+
tmp1 = (m)
|
| 605 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 606 |
+
tmp3 = tmp1 < tmp2
|
| 607 |
+
tmp4 = (n)
|
| 608 |
+
tmp5 = tmp4 <= tmp1
|
| 609 |
+
tmp6 = tmp3 & tmp5
|
| 610 |
+
tmp7 = tmp1 >= tmp2
|
| 611 |
+
tmp8 = tmp4 < tmp2
|
| 612 |
+
tmp9 = tmp7 & tmp8
|
| 613 |
+
tmp10 = tmp8 == 0
|
| 614 |
+
tmp11 = tmp7 & tmp10
|
| 615 |
+
tmp12 = tmp1 - tmp2
|
| 616 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 617 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 618 |
+
tmp15 = tmp4 - tmp2
|
| 619 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 620 |
+
tmp17 = tmp14 == tmp16
|
| 621 |
+
tmp18 = tmp11 & tmp17
|
| 622 |
+
tmp19 = tmp9 | tmp18
|
| 623 |
+
tmp20 = tmp6 | tmp19
|
| 624 |
+
mask_mod_output = tmp20
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
# apply mask for partial masked block
|
| 628 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 629 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 630 |
+
if not PRESCALE_QK:
|
| 631 |
+
post_mod_scores *= RCP_LN2
|
| 632 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 633 |
+
# Compute dP and dS.
|
| 634 |
+
# NB reversed order to since V is transposed
|
| 635 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 636 |
+
|
| 637 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 638 |
+
ds = p * (dp - Di[:, None])
|
| 639 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 640 |
+
tmp21 = (ds)
|
| 641 |
+
grad_scores = tmp21
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
if not IS_DIVISIBLE:
|
| 645 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 646 |
+
|
| 647 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 648 |
+
if WRITE_DQ:
|
| 649 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 650 |
+
|
| 651 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 652 |
+
ds = grad_scores
|
| 653 |
+
|
| 654 |
+
if not IS_FULL_BLOCKS:
|
| 655 |
+
# (grads) apply mask for partially unmasked block
|
| 656 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 657 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 658 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 659 |
+
# Compute dQ.
|
| 660 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 661 |
+
|
| 662 |
+
return dq
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
@triton.jit
|
| 666 |
+
def bwd_dkdv_inner(
|
| 667 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 668 |
+
Q, DO, DELTA, LSE, # pointers
|
| 669 |
+
dk, dv, k, v,
|
| 670 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 671 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 672 |
+
q_indices, sparse_q_num_blocks,
|
| 673 |
+
MATMUL_PRECISION,
|
| 674 |
+
IS_FULL_BLOCKS,
|
| 675 |
+
):
|
| 676 |
+
PRESCALE_QK : tl.constexpr = False
|
| 677 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 678 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 679 |
+
WRITE_DQ : tl.constexpr = True
|
| 680 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 681 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 682 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 683 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 684 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 685 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 686 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 687 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 688 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 689 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 690 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 691 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 692 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 693 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 694 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 695 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 696 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 697 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 698 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 699 |
+
|
| 700 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 701 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 702 |
+
Q_LEN = 1712
|
| 703 |
+
KV_LEN = 1712
|
| 704 |
+
|
| 705 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 706 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 707 |
+
|
| 708 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 709 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 710 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 711 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 712 |
+
|
| 713 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 714 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 715 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 716 |
+
|
| 717 |
+
for start_m in range(0, hi):
|
| 718 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 719 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 720 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 721 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 722 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 723 |
+
q_indices, sparse_q_num_blocks,
|
| 724 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 725 |
+
IS_FULL_BLOCKS,
|
| 726 |
+
)
|
| 727 |
+
# Increment pointers.
|
| 728 |
+
offset = get_offset_for_next_block(
|
| 729 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 730 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
qT_ptrs += offset * stride_qm
|
| 734 |
+
do_ptrs += offset * stride_dom
|
| 735 |
+
offs_m1 += offset
|
| 736 |
+
|
| 737 |
+
return dk, dv
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
@triton.jit
|
| 741 |
+
def bwd_dkdv_block_mn(
|
| 742 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 743 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 744 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 745 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 746 |
+
q_indices, sparse_q_num_blocks,
|
| 747 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 748 |
+
IS_FULL_BLOCKS,
|
| 749 |
+
):
|
| 750 |
+
PRESCALE_QK : tl.constexpr = False
|
| 751 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 752 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 753 |
+
WRITE_DQ : tl.constexpr = True
|
| 754 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 755 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 756 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 757 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 758 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 759 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 760 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 761 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 762 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 763 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 764 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 765 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 766 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 767 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 768 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 769 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 770 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 771 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 772 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# NB reversed order since Q is transposed
|
| 776 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 777 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 778 |
+
if IS_DIVISIBLE:
|
| 779 |
+
lse = tl.load(LSE + offs_m1)
|
| 780 |
+
else:
|
| 781 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 782 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 783 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 784 |
+
if not PRESCALE_QK:
|
| 785 |
+
qkT *= SM_SCALE
|
| 786 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 787 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 788 |
+
# The boundary check is done for the outer loop, but since we're iterating across the M dim here,
|
| 789 |
+
# the n reads can still go out of bounds for the PIDs spanning the KV_LEN boundary
|
| 790 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 791 |
+
|
| 792 |
+
pre_mod_scores = qkT
|
| 793 |
+
tmp22 = (qkT)
|
| 794 |
+
post_mod_scores = tmp22
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
if not IS_DIVISIBLE:
|
| 799 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 800 |
+
|
| 801 |
+
if not IS_FULL_BLOCKS:
|
| 802 |
+
tmp23 = (m)
|
| 803 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 804 |
+
tmp25 = tmp23 < tmp24
|
| 805 |
+
tmp26 = (n)
|
| 806 |
+
tmp27 = tmp26 <= tmp23
|
| 807 |
+
tmp28 = tmp25 & tmp27
|
| 808 |
+
tmp29 = tmp23 >= tmp24
|
| 809 |
+
tmp30 = tmp26 < tmp24
|
| 810 |
+
tmp31 = tmp29 & tmp30
|
| 811 |
+
tmp32 = tmp30 == 0
|
| 812 |
+
tmp33 = tmp29 & tmp32
|
| 813 |
+
tmp34 = tmp23 - tmp24
|
| 814 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 815 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 816 |
+
tmp37 = tmp26 - tmp24
|
| 817 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 818 |
+
tmp39 = tmp36 == tmp38
|
| 819 |
+
tmp40 = tmp33 & tmp39
|
| 820 |
+
tmp41 = tmp31 | tmp40
|
| 821 |
+
tmp42 = tmp28 | tmp41
|
| 822 |
+
mask_mod_output = tmp42
|
| 823 |
+
|
| 824 |
+
# (grads) apply mask for fully masked block
|
| 825 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 826 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 827 |
+
if not PRESCALE_QK:
|
| 828 |
+
post_mod_scores *= RCP_LN2
|
| 829 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 830 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 831 |
+
# Compute dV.
|
| 832 |
+
ppT = pT
|
| 833 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 834 |
+
if IS_DIVISIBLE:
|
| 835 |
+
Di = tl.load(DELTA + offs_m1)
|
| 836 |
+
else:
|
| 837 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 838 |
+
# Compute dP and dS.
|
| 839 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 840 |
+
dsT = pT * (dpT - Di[None, :])
|
| 841 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 842 |
+
tmp43 = (dsT)
|
| 843 |
+
grad_scores = tmp43
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
if not IS_DIVISIBLE:
|
| 848 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 849 |
+
|
| 850 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 851 |
+
if not WRITE_DQ:
|
| 852 |
+
idx_b = off_z
|
| 853 |
+
idx_h = off_hq
|
| 854 |
+
idx_m = m
|
| 855 |
+
idx_n = n
|
| 856 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 857 |
+
|
| 858 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 859 |
+
dsT = grad_scores
|
| 860 |
+
if not IS_FULL_BLOCKS:
|
| 861 |
+
# (grads) apply mask for partially unmasked block
|
| 862 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 863 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 864 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 865 |
+
|
| 866 |
+
return dk, dv
|
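# --- Illustrative sketch (not generated by Inductor) ---
# bwd_dkdv_block_mn implements the dK/dV side of the same backward:
# dV += P^T @ dO and dK += dS^T @ Q with dS = P * (dP - Delta). A dense
# eager-mode reference, assuming plain torch tensors and no block sparsity:
import torch

def dkdv_reference(q, k, v, do, lse, delta, sm_scale):
    # q, do: [M, D]; k, v: [N, D]; lse, delta: [M]
    s = (q @ k.transpose(-1, -2)) * sm_scale
    p = torch.exp(s - lse[:, None])              # recover P from the saved LSE
    dv = p.transpose(-1, -2) @ do                # dV = P^T @ dO  ("ppT @ do" above)
    dp = do @ v.transpose(-1, -2)
    ds = p * (dp - delta[:, None])               # dS = P * (dP - Delta)
    dk = sm_scale * (ds.transpose(-1, -2) @ q)   # dK = scale * dS^T @ Q
    return dk, dv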
| 867 |
+
|
| 868 |
+
# Utility triton funcs
|
| 869 |
+
@triton.jit
|
| 870 |
+
def get_offset_for_next_block(
|
| 871 |
+
loop_iter, col_indices, total_blocks,
|
| 872 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 873 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 874 |
+
):
|
| 875 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 876 |
+
return BLOCK
|
| 877 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 878 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 879 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 880 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 881 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 882 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 883 |
+
return offset
|
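# --- Illustrative sketch (not generated by Inductor) ---
# get_offset_for_next_block advances the row pointers by BLOCK while the loop
# stays inside the current sparse block, and by a larger "jump" once it crosses
# into the next sparse block listed in col_indices. Same arithmetic in plain
# Python, assuming col_indices is a list of block indices:
def offset_for_next_block(loop_iter, col_indices, sparse_block, multiple, block):
    cur_idx = loop_iter // multiple
    cur_block = col_indices[cur_idx]
    next_block = col_indices[cur_idx + 1] if cur_idx + 1 < len(col_indices) else cur_block
    if (loop_iter + 1) % multiple == 0:
        # jump to the start of the next sparse block, undoing the (multiple - 1)
        # BLOCK-sized steps already taken inside the current one
        return (next_block - cur_block) * sparse_block - (multiple - 1) * block
    return block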
| 884 |
+
|
| 885 |
+
@triton.jit
|
| 886 |
+
def get_bounded_indices(indices, max_len=None):
|
| 887 |
+
return indices % max_len if max_len is not None else indices
|
| 888 |
+
|
| 889 |
+
@triton.jit
|
| 890 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 891 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 892 |
+
return tl.load(block_ptr)
|
| 893 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 894 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 895 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 896 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 897 |
+
else:
|
| 898 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 899 |
+
|
| 900 |
+
@triton.jit
|
| 901 |
+
def load_checked_2d(
|
| 902 |
+
ptr,
|
| 903 |
+
offs_m,
|
| 904 |
+
offs_n,
|
| 905 |
+
stride_m,
|
| 906 |
+
stride_n,
|
| 907 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 908 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 909 |
+
M_LEN: tl.constexpr,
|
| 910 |
+
N_LEN: tl.constexpr,
|
| 911 |
+
):
|
| 912 |
+
# Calculate final pointer if strides are provided
|
| 913 |
+
if stride_m is not None and stride_n is not None:
|
| 914 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 915 |
+
|
| 916 |
+
# Handle all masking cases
|
| 917 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 918 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 919 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 920 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 921 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 922 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 923 |
+
else: # Both divisible
|
| 924 |
+
return tl.load(ptr)
|
| 925 |
+
''', device_str='cuda')
|
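# --- Illustrative sketch (not generated by Inductor) ---
# load_checked_2d only adds a mask (with zero padding) along the dimensions
# that are not guaranteed divisible by the tile size. An eager-mode equivalent
# of the fully masked case, assuming a 2D tensor `t` and integer offset vectors:
import torch

def load_checked_2d_reference(t, offs_m, offs_n):
    m_ok = offs_m[:, None] < t.shape[0]
    n_ok = offs_n[None, :] < t.shape[1]
    # clamp so the gather itself stays in bounds, then zero out the padded tail
    mi = offs_m.clamp(max=t.shape[0] - 1)[:, None]
    ni = offs_n.clamp(max=t.shape[1] - 1)[None, :]
    tile = t[mi, ni]
    return torch.where(m_ok & n_ok, tile, torch.zeros_like(tile))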
| 926 |
+
|
| 927 |
+
|
| 928 |
+
async_compile.wait(globals())
|
| 929 |
+
del async_compile
|
| 930 |
+
|
| 931 |
+
class Runner:
|
| 932 |
+
def __init__(self, partitions):
|
| 933 |
+
self.partitions = partitions
|
| 934 |
+
|
| 935 |
+
def recursively_apply_fns(self, fns):
|
| 936 |
+
new_callables = []
|
| 937 |
+
for fn, c in zip(fns, self.partitions):
|
| 938 |
+
new_callables.append(fn(c))
|
| 939 |
+
self.partitions = new_callables
|
| 940 |
+
|
| 941 |
+
def call(self, args):
|
| 942 |
+
primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args
|
| 943 |
+
args.clear()
|
| 944 |
+
assert_size_stride(primals_1, (1, 32, 1712, 128), (7012352, 128, 4096, 1))
|
| 945 |
+
assert_size_stride(primals_2, (1, 8, 1712, 128), (1753088, 128, 1024, 1))
|
| 946 |
+
assert_size_stride(primals_3, (1, 8, 1712, 128), (1753088, 128, 1024, 1))
|
| 947 |
+
assert_size_stride(primals_4, (1, 1, 14, 14), (196, 196, 14, 1))
|
| 948 |
+
assert_size_stride(primals_5, (1, 1, 14), (14, 14, 1))
|
| 949 |
+
assert_size_stride(primals_6, (1, 1, 14), (14, 14, 1))
|
| 950 |
+
assert_size_stride(primals_7, (1, 1, 14, 14), (196, 196, 14, 1))
|
| 951 |
+
assert_size_stride(primals_8, (1, 1, 14), (14, 14, 1))
|
| 952 |
+
assert_size_stride(primals_9, (1, 1, 14, 14), (196, 196, 14, 1))
|
| 953 |
+
assert_size_stride(primals_10, (1, 1, 14), (14, 14, 1))
|
| 954 |
+
assert_size_stride(primals_11, (1, 1, 14, 14), (196, 196, 14, 1))
|
| 955 |
+
assert_size_stride(getitem, (1, 32, 1712, 128), (7012352, 128, 4096, 1))
|
| 956 |
+
assert_size_stride(getitem_1, (1, 32, 1712), (54784, 1712, 1))
|
| 957 |
+
assert_size_stride(tangents_1, (1, 32, 1712, 128), (7012352, 219136, 128, 1))
|
| 958 |
+
assert_size_stride(tangents_2, (1, 32, 1712), (54784, 1712, 1))
|
| 959 |
+
with torch.cuda._DeviceGuard(2):
|
| 960 |
+
torch.cuda.set_device(2)
|
| 961 |
+
buf1 = empty_strided_cuda((1, 32, 1712), (54784, 1712, 1), torch.float32)
|
| 962 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 963 |
+
stream2 = get_raw_stream(2)
|
| 964 |
+
triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 54784, 128, stream=stream2)
|
| 965 |
+
del getitem
|
| 966 |
+
del tangents_2
|
| 967 |
+
buf3 = empty_strided_cuda((1, 32, 1712, 128), (7012352, 128, 4096, 1), torch.bfloat16)
|
| 968 |
+
buf4 = empty_strided_cuda((1, 8, 1712, 128), (1753088, 128, 1024, 1), torch.bfloat16)
|
| 969 |
+
buf5 = empty_strided_cuda((1, 8, 1712, 128), (1753088, 128, 1024, 1), torch.bfloat16)
|
| 970 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 971 |
+
stream2 = get_raw_stream(2)
|
| 972 |
+
triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 70, 1, 8, stream=stream2)
|
| 973 |
+
del buf1
|
| 974 |
+
del getitem_1
|
| 975 |
+
del primals_1
|
| 976 |
+
del primals_10
|
| 977 |
+
del primals_11
|
| 978 |
+
del primals_2
|
| 979 |
+
del primals_3
|
| 980 |
+
del primals_4
|
| 981 |
+
del primals_5
|
| 982 |
+
del primals_6
|
| 983 |
+
del primals_7
|
| 984 |
+
del primals_8
|
| 985 |
+
del primals_9
|
| 986 |
+
del tangents_1
|
| 987 |
+
return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, )
|
| 988 |
+
|
| 989 |
+
runner = Runner(partitions=[])
|
| 990 |
+
call = runner.call
|
| 991 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 995 |
+
from torch._dynamo.testing import rand_strided
|
| 996 |
+
from torch._inductor.utils import print_performance
|
| 997 |
+
primals_1 = rand_strided((1, 32, 1712, 128), (7012352, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 998 |
+
primals_2 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 999 |
+
primals_3 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 1000 |
+
primals_4 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1001 |
+
primals_5 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1002 |
+
primals_6 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1003 |
+
primals_7 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1004 |
+
primals_8 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1005 |
+
primals_9 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1006 |
+
primals_10 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1007 |
+
primals_11 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32)
|
| 1008 |
+
getitem = rand_strided((1, 32, 1712, 128), (7012352, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 1009 |
+
getitem_1 = rand_strided((1, 32, 1712), (54784, 1712, 1), device='cuda:2', dtype=torch.float32)
|
| 1010 |
+
tangents_1 = rand_strided((1, 32, 1712, 128), (7012352, 219136, 128, 1), device='cuda:2', dtype=torch.bfloat16)
|
| 1011 |
+
tangents_2 = rand_strided((1, 32, 1712), (54784, 1712, 1), device='cuda:2', dtype=torch.float32)
|
| 1012 |
+
fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2])
|
| 1013 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 1014 |
+
|
| 1015 |
+
|
| 1016 |
+
if __name__ == "__main__":
|
| 1017 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 1018 |
+
compiled_module_main('None', benchmark_compiled_module)
|
progress/SpecForge/cache/compiled_kernels/2x/c2xecscuz5jhvznv7jn4k545b7kcexuko5lz3em6woeo7u2ftonz.py
ADDED
|
@@ -0,0 +1,58 @@
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.persistent_reduction(
|
| 11 |
+
size_hints={'x': 2048, 'r0_': 32},
|
| 12 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 19 |
+
r0_numel = 32
|
| 20 |
+
R0_BLOCK: tl.constexpr = 32
|
| 21 |
+
rnumel = r0_numel
|
| 22 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 23 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 24 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 25 |
+
xmask = xindex < xnumel
|
| 26 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 27 |
+
r0_offset = 0
|
| 28 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 29 |
+
roffset = r0_offset
|
| 30 |
+
rindex = r0_index
|
| 31 |
+
r0_1 = r0_index
|
| 32 |
+
x0 = xindex
|
| 33 |
+
x2 = (xindex % ks0)
|
| 34 |
+
x3 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 35 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 36 |
+
tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 37 |
+
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
|
| 38 |
+
tmp3 = tl.where(xmask, tmp1, float("-inf"))
|
| 39 |
+
tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
|
| 40 |
+
tmp6 = float("-inf")
|
| 41 |
+
tmp7 = tmp4 == tmp6
|
| 42 |
+
tmp8 = tmp0 - tmp4
|
| 43 |
+
tmp9 = 0.0
|
| 44 |
+
tmp10 = tl.where(tmp7, tmp9, tmp8)
|
| 45 |
+
tmp11 = libdevice.exp2(tmp10)
|
| 46 |
+
tmp12 = tmp5 * tmp11
|
| 47 |
+
tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
|
| 48 |
+
tmp15 = tl.where(xmask, tmp13, 0)
|
| 49 |
+
tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
|
| 50 |
+
tmp17 = 1.0
|
| 51 |
+
tmp18 = tl.where(tmp7, tmp17, tmp16)
|
| 52 |
+
tmp19 = libdevice.log2(tmp18)
|
| 53 |
+
tmp20 = tmp19 + tmp4
|
| 54 |
+
tmp21 = 0.6931471805599453
|
| 55 |
+
tmp22 = tmp20 * tmp21
|
| 56 |
+
tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
|
| 57 |
+
tl.store(out_ptr0 + (x0), tmp4, xmask)
|
| 58 |
+
tl.store(out_ptr1 + (x0), tmp16, xmask)
|
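# --- Illustrative sketch (not generated by Inductor) ---
# The reduction above is a max-stabilised, weighted log-sum-exp carried out in
# base 2 and converted to natural log at the end (0.6931... = ln 2), with the
# sum replaced by 1.0 when the whole row is masked (max == -inf). A plain-Python
# reading of the same arithmetic for one row:
import math

def weighted_lse_row(x, w):
    m = max(x)
    s = 1.0 if m == float("-inf") else sum(wj * 2.0 ** (xj - m) for xj, wj in zip(x, w))
    return math.log(2.0) * (math.log2(s) + m)   # fully masked rows stay -inf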
progress/SpecForge/cache/compiled_kernels/2x/c2ximikyisa7xxnki36flzcsdr4ziwruq7ujf3zymsuxon5pqv57.py
ADDED
|
@@ -0,0 +1,707 @@
| 1 |
+
# AOT ID: ['4_forward']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ue/cueevzv7vmofb7xgliazywafcctru3ytzoxylqvshvpe6nweecz6.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# flex_attention => flex_attention
|
| 41 |
+
# Graph fragment:
|
| 42 |
+
# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2]
|
| 43 |
+
# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4]
|
| 44 |
+
# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6]
|
| 45 |
+
# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=getitem_1]
|
| 46 |
+
# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1]
|
| 47 |
+
# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_10]
|
| 48 |
+
# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_7]
|
| 49 |
+
# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_11]
|
| 50 |
+
# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_12]
|
| 51 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 52 |
+
# return %getitem
|
| 53 |
+
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
|
| 54 |
+
import triton
|
| 55 |
+
import triton.language as tl
|
| 56 |
+
|
| 57 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 58 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 59 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 60 |
+
|
| 61 |
+
@triton_heuristics.template(
|
| 62 |
+
|
| 63 |
+
num_stages=3,
|
| 64 |
+
num_warps=8,
|
| 65 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 66 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 67 |
+
|
| 68 |
+
)
|
| 69 |
+
@triton.jit
|
| 70 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
|
| 71 |
+
PRESCALE_QK : tl.constexpr = False
|
| 72 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 73 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 74 |
+
WRITE_DQ : tl.constexpr = True
|
| 75 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 76 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 77 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 78 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 79 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 80 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 81 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 82 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 83 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 84 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 85 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 86 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 87 |
+
USE_TMA : tl.constexpr = False
|
| 88 |
+
BLOCK_M : tl.constexpr = 128
|
| 89 |
+
BLOCK_N : tl.constexpr = 64
|
| 90 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 91 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 92 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 93 |
+
Q = arg_Q
|
| 94 |
+
K = arg_K
|
| 95 |
+
V = arg_V
|
| 96 |
+
LSE = arg_LSE
|
| 97 |
+
MAX = arg_MAX
|
| 98 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 99 |
+
KV_IDX = arg_KV_IDX
|
| 100 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 101 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 102 |
+
|
| 103 |
+
# Sub notation for this kernel:
|
| 104 |
+
#
|
| 105 |
+
# Q: Query, K: Key, V: Value
|
| 106 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 107 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 108 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 109 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 110 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 111 |
+
#
|
| 112 |
+
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
| 113 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 114 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 115 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 116 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 117 |
+
#
|
| 118 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 119 |
+
#
|
| 120 |
+
# (Modifiable) Performance tuning options
|
| 121 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 122 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 123 |
+
|
| 124 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 125 |
+
# or involve a numerics vs. perf tradeoff
|
| 126 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 127 |
+
# about 20% more numerical error, but slightly faster.
|
| 128 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 129 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 130 |
+
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
| 131 |
+
# contiguous? If so, we don't need to do an indirect jump for every block
|
| 132 |
+
|
| 133 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 134 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 135 |
+
|
| 136 |
+
# Define strides of inputs
|
| 137 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 138 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 139 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
|
| 140 |
+
|
| 141 |
+
ZQ = 1
|
| 142 |
+
HQ = 32
|
| 143 |
+
Q_LEN = ks0
|
| 144 |
+
ZKV = 1
|
| 145 |
+
KV_LEN = ks1
|
| 146 |
+
|
| 147 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 148 |
+
|
| 149 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 150 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 151 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 152 |
+
|
| 153 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 154 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 155 |
+
off_zkv = off_zq % ZKV
|
| 156 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 157 |
+
off_g = off_hq % GQA_SHARED_HEADS
|
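# --- Illustrative note (not generated by Inductor) ---
# GQA head mapping: with GQA_SHARED_HEADS = 4 and HQ = 32 query heads, every
# group of 4 consecutive query heads shares one of the 8 KV heads.
def kv_head_for(query_head: int, shared_heads: int = 4) -> int:
    return query_head // shared_heads
# e.g. query heads 0-3 -> kv head 0, 4-7 -> kv head 1, ..., 28-31 -> kv head 7.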
| 158 |
+
|
| 159 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 160 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 161 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 162 |
+
|
| 163 |
+
Q = Q + q_offset
|
| 164 |
+
K = K + k_offset
|
| 165 |
+
V = V + v_offset
|
| 166 |
+
|
| 167 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 168 |
+
desc_q = None
|
| 169 |
+
desc_k = None
|
| 170 |
+
desc_v = None
|
| 171 |
+
|
| 172 |
+
SPARSE_Z = 1
|
| 173 |
+
SPARSE_HQ = 1
|
| 174 |
+
|
| 175 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 176 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 177 |
+
|
| 178 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 179 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 180 |
+
|
| 181 |
+
stride_kv_num_blks_h = 1
|
| 182 |
+
stride_kv_idx_h = 1
|
| 183 |
+
stride_kv_idx_m = 1
|
| 184 |
+
|
| 185 |
+
# initialize pointer to m and l
|
| 186 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 187 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 188 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 189 |
+
|
| 190 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 191 |
+
|
| 192 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 193 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 194 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 195 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 196 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 197 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 198 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 199 |
+
|
| 200 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 201 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 202 |
+
# both score_mod and mask_mod to it
|
| 203 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 204 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 205 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 206 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# K and V pointers will be passed directly to forward_inner
|
| 210 |
+
|
| 211 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
acc, l_i, m_i = forward_inner(
|
| 215 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 216 |
+
q, K, V,
|
| 217 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 218 |
+
acc, l_i, m_i,
|
| 219 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 220 |
+
kv_start,
|
| 221 |
+
kv_indices, kv_num_blocks,
|
| 222 |
+
0, block_n_end,
|
| 223 |
+
MATMUL_PRECISION,
|
| 224 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 225 |
+
IS_FULL_BLOCKS=False,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 229 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 230 |
+
# apply mask_mod to them - only score_mod
|
| 231 |
+
if HAS_FULL_BLOCKS:
|
| 232 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 233 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 234 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 235 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 236 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 237 |
+
# K and V pointers will be passed directly to forward_inner
|
| 238 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 239 |
+
|
| 240 |
+
acc, l_i, m_i = forward_inner(
|
| 241 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 242 |
+
q, K, V,
|
| 243 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 244 |
+
acc, l_i, m_i,
|
| 245 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 246 |
+
kv_start,
|
| 247 |
+
kv_indices, kv_num_blocks,
|
| 248 |
+
0, block_n_end,
|
| 249 |
+
MATMUL_PRECISION,
|
| 250 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 251 |
+
IS_FULL_BLOCKS=True,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# [Note] Handle fully masked out rows:
|
| 256 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 257 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 258 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 259 |
+
|
| 260 |
+
acc = acc / l_i[:, None]
|
| 261 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 262 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 263 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 264 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 265 |
+
|
| 266 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 267 |
+
|
| 268 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 269 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 270 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 271 |
+
|
| 272 |
+
if OUTPUT_LOGSUMEXP:
|
| 273 |
+
off_hz = off_zq * HQ + off_hq
|
| 274 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 275 |
+
lse = m_i + tl.math.log2(l_i)
|
| 276 |
+
if IS_DIVISIBLE:
|
| 277 |
+
tl.store(l_ptrs, lse)
|
| 278 |
+
else:
|
| 279 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 280 |
+
|
| 281 |
+
if OUTPUT_MAX:
|
| 282 |
+
off_hz = off_zq * HQ + off_hq
|
| 283 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 284 |
+
if IS_DIVISIBLE:
|
| 285 |
+
tl.store(max_ptrs, m_i)
|
| 286 |
+
else:
|
| 287 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# Utility triton funcs
|
| 291 |
+
@triton.jit
|
| 292 |
+
def get_offset_for_next_block(
|
| 293 |
+
loop_iter, col_indices, total_blocks,
|
| 294 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 295 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 296 |
+
):
|
| 297 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 298 |
+
return BLOCK
|
| 299 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 300 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 301 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 302 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 303 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 304 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 305 |
+
return offset
|
| 306 |
+
|
| 307 |
+
@triton.jit
|
| 308 |
+
def get_bounded_indices(indices, max_len=None):
|
| 309 |
+
return indices % max_len if max_len is not None else indices
|
| 310 |
+
|
| 311 |
+
@triton.jit
|
| 312 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 313 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 314 |
+
return tl.load(block_ptr)
|
| 315 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 316 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 317 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 318 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 319 |
+
else:
|
| 320 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 321 |
+
|
| 322 |
+
@triton.jit
|
| 323 |
+
def load_checked_2d(
|
| 324 |
+
ptr,
|
| 325 |
+
offs_m,
|
| 326 |
+
offs_n,
|
| 327 |
+
stride_m,
|
| 328 |
+
stride_n,
|
| 329 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 330 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 331 |
+
M_LEN: tl.constexpr,
|
| 332 |
+
N_LEN: tl.constexpr,
|
| 333 |
+
):
|
| 334 |
+
# Calculate final pointer if strides are provided
|
| 335 |
+
if stride_m is not None and stride_n is not None:
|
| 336 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 337 |
+
|
| 338 |
+
# Handle all masking cases
|
| 339 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 340 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 341 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 342 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 343 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 344 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 345 |
+
else: # Both divisible
|
| 346 |
+
return tl.load(ptr)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# Common Imports
|
| 350 |
+
@triton.jit
|
| 351 |
+
def forward_block_mn(
|
| 352 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 353 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 354 |
+
# accumulated values
|
| 355 |
+
acc, l_i, m_i,
|
| 356 |
+
# Offsets
|
| 357 |
+
off_z, off_h, offs_m, offs_n,
|
| 358 |
+
# Offsets needed for TMA loads
|
| 359 |
+
kv_start,
|
| 360 |
+
kv_offset,
|
| 361 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 362 |
+
# Strides for K and V
|
| 363 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 364 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 365 |
+
|
| 366 |
+
):
|
| 367 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 368 |
+
PRESCALE_QK : tl.constexpr = False
|
| 369 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 370 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 371 |
+
WRITE_DQ : tl.constexpr = True
|
| 372 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 373 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 374 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 375 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 376 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 377 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 378 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 379 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 380 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 381 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 382 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 383 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 384 |
+
USE_TMA : tl.constexpr = False
|
| 385 |
+
BLOCK_M : tl.constexpr = 128
|
| 386 |
+
BLOCK_N : tl.constexpr = 64
|
| 387 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 388 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 389 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# -- load k --
|
| 393 |
+
# NB reversed order since K is transposed
|
| 394 |
+
kv_base_offset = kv_start + kv_offset
|
| 395 |
+
|
| 396 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 397 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 398 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 399 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 400 |
+
|
| 401 |
+
k = tl.trans(k)
|
| 402 |
+
# -- compute qk ---
|
| 403 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 404 |
+
if not PRESCALE_QK:
|
| 405 |
+
qk *= SM_SCALE
|
| 406 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 407 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 408 |
+
# which is larger than the actual number of elements. To avoid access memory out of bound,
|
| 409 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 410 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 411 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 412 |
+
|
| 413 |
+
tmp0 = (qk)
|
| 414 |
+
post_mod_scores = tmp0
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 418 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 419 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 420 |
+
|
| 421 |
+
if not IS_FULL_BLOCKS:
|
| 422 |
+
tmp1 = (m)
|
| 423 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 424 |
+
tmp3 = tmp1 < tmp2
|
| 425 |
+
tmp4 = (n)
|
| 426 |
+
tmp5 = tmp4 <= tmp1
|
| 427 |
+
tmp6 = tmp3 & tmp5
|
| 428 |
+
tmp7 = tmp1 >= tmp2
|
| 429 |
+
tmp8 = tmp4 < tmp2
|
| 430 |
+
tmp9 = tmp7 & tmp8
|
| 431 |
+
tmp10 = tmp8 == 0
|
| 432 |
+
tmp11 = tmp7 & tmp10
|
| 433 |
+
tmp12 = tmp1 - tmp2
|
| 434 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 435 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 436 |
+
tmp15 = tmp4 - tmp2
|
| 437 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 438 |
+
tmp17 = tmp14 == tmp16
|
| 439 |
+
tmp18 = tmp11 & tmp17
|
| 440 |
+
tmp19 = tmp9 | tmp18
|
| 441 |
+
tmp20 = tmp6 | tmp19
|
| 442 |
+
mask_mod_output = tmp20
|
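# --- Illustrative sketch (not generated by Inductor) ---
# Read as plain Python, the inlined mask_mod above (tmp1..tmp20) is a
# block-diagonal predicate with block size 16, plus two branches that only
# matter for negative indices:
def mask_mod_reading(m: int, n: int) -> bool:
    return ((m < 0 and n <= m)
            or (m >= 0 and n < 0)
            or (m >= 0 and n >= 0 and m // 16 == n // 16))
# The tl.where(... % ... // ...) expressions above only emulate floor division
# for possibly-negative operands; for non-negative m and n they reduce to m // 16 and n // 16.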
| 443 |
+
|
| 444 |
+
|
| 445 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 446 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 447 |
+
# apply mask for partially unmasked blocks
|
| 448 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 449 |
+
|
| 450 |
+
if not PRESCALE_QK:
|
| 451 |
+
post_mod_scores *= RCP_LN2
|
| 452 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 453 |
+
|
| 454 |
+
# -- compute scaling constant ---
|
| 455 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 456 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 457 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 458 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 459 |
+
else:
|
| 460 |
+
m_ij_masked = m_ij
|
| 461 |
+
|
| 462 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 463 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 464 |
+
|
| 465 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 466 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 467 |
+
# m_ij
|
| 468 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 469 |
+
# # -- scale and update acc --
|
| 470 |
+
acc = acc * alpha[:, None]
|
| 471 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 472 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 473 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 474 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 475 |
+
|
| 476 |
+
# -- update m_i
|
| 477 |
+
m_i = m_ij
|
| 478 |
+
|
| 479 |
+
return acc, l_i, m_i
|
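# --- Illustrative sketch (not generated by Inductor) ---
# forward_block_mn is one step of the online-softmax recurrence: keep a running
# row max m_i, running normalizer l_i, and unnormalised accumulator acc, and
# rescale them whenever a new tile raises the max. A dense, base-e reference
# of the same update, assuming torch tensors (the kernel works in base 2):
import torch

def online_softmax_step(acc, l_i, m_i, scores, v):
    # scores: [M, N] for the current tile; v: [N, D]; acc: [M, D]; l_i, m_i: [M]
    m_new = torch.maximum(m_i, scores.max(dim=1).values)
    alpha = torch.exp(m_i - m_new)                 # rescale previous state to the new max
    p = torch.exp(scores - m_new[:, None])
    l_i = l_i * alpha + p.sum(dim=1)
    acc = acc * alpha[:, None] + p @ v
    return acc, l_i, m_new
# After the last tile the epilogue divides acc by l_i and stores lse = m_i + log(l_i).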
| 480 |
+
|
| 481 |
+
@triton.jit
|
| 482 |
+
def forward_inner(
|
| 483 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 484 |
+
q, K, V,
|
| 485 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 486 |
+
# accumulated values
|
| 487 |
+
acc, l_i, m_i,
|
| 488 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 489 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 490 |
+
off_z, off_h, offs_m, offs_n,
|
| 491 |
+
# Offsets needed for TMA loads
|
| 492 |
+
kv_start,
|
| 493 |
+
# blocksparse data
|
| 494 |
+
kv_indices, kv_num_blocks,
|
| 495 |
+
# start kv and end kv block
|
| 496 |
+
block_n_start, block_n_end,
|
| 497 |
+
MATMUL_PRECISION,
|
| 498 |
+
# Strides for K and V
|
| 499 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 500 |
+
IS_FULL_BLOCKS,
|
| 501 |
+
):
|
| 502 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 503 |
+
PRESCALE_QK : tl.constexpr = False
|
| 504 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 505 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 506 |
+
WRITE_DQ : tl.constexpr = True
|
| 507 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 508 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 509 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 510 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 511 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 512 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 513 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 514 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 515 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 516 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 517 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 518 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 519 |
+
USE_TMA : tl.constexpr = False
|
| 520 |
+
BLOCK_M : tl.constexpr = 128
|
| 521 |
+
BLOCK_N : tl.constexpr = 64
|
| 522 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 523 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 524 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 528 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 529 |
+
|
| 530 |
+
if PRESCALE_QK:
|
| 531 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 532 |
+
|
| 533 |
+
kv_offset = 0
|
| 534 |
+
|
| 535 |
+
# loop over k, v and update accumulator until block_n_end
|
| 536 |
+
for start_n in range(block_n_start, block_n_end):
|
| 537 |
+
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) hint from triton_fused_attention.
|
| 538 |
+
if IS_DIVISIBLE:
|
| 539 |
+
acc, l_i, m_i = forward_block_mn(
|
| 540 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 541 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 542 |
+
# accumulated values
|
| 543 |
+
acc, l_i, m_i,
|
| 544 |
+
# Offsets
|
| 545 |
+
off_z, off_h, offs_m, offs_n,
|
| 546 |
+
# Offsets needed for TMA loads
|
| 547 |
+
kv_start,
|
| 548 |
+
kv_offset,
|
| 549 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 550 |
+
# Strides for K and V
|
| 551 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 552 |
+
IS_FULL_BLOCKS,
|
| 553 |
+
)
|
| 554 |
+
else:
|
| 555 |
+
# Benchmarks show that even if we apply mod & mask to each block for a non-divisible seqlen,
|
| 556 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 557 |
+
# However, we choose a different strategy for bwd, where we only apply mod & mask
|
| 558 |
+
# to the last block because it's a lot faster.
|
| 559 |
+
acc, l_i, m_i = forward_block_mn(
|
| 560 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 561 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 562 |
+
# accumulated values
|
| 563 |
+
acc, l_i, m_i,
|
| 564 |
+
# Offsets
|
| 565 |
+
off_z, off_h, offs_m, offs_n,
|
| 566 |
+
# Offsets needed for TMA loads
|
| 567 |
+
kv_start,
|
| 568 |
+
kv_offset,
|
| 569 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 570 |
+
# Strides for K and V
|
| 571 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 572 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
offset = get_offset_for_next_block(
|
| 578 |
+
start_n, kv_indices, kv_num_blocks,
|
| 579 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
offs_n = offs_n + offset
|
| 583 |
+
kv_offset += offset
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
return acc, l_i, m_i
|
| 587 |
+
''', device_str='cuda')
+
+
+# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/iv/civfnuj45ifcjf545fyz4ryg2q422dh2larygr2jr5yyli7gz6ae.py
+# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
+# Source node to ATen node mapping:
+#   lse_scaled => mul_15
+# Graph fragment:
+#   %buf3 : Tensor = PlaceHolder[target=buf3]
+#   %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
+#   return %mul_15
+triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 4096},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x2 = xindex
+    x0 = (xindex % ks0)
+    x1 = triton_helpers.div_floor_integer(xindex, ks0)
+    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
+    tmp1 = 0.6931471805599453
+    tmp2 = tmp0 * tmp1
+    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
+''', device_str='cuda')
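The only arithmetic in triton_poi_fused_mul_1 is a multiply by 0.6931471805599453, which is ln(2): the attention template accumulates the logsumexp in base 2 (it works with exp2 and RCP_LN2 = 1/ln(2)), and this kernel rescales the stored LSE back to the natural log. A minimal sketch of that relationship, written for this note with assumed toy shapes:

# Sketch only; mirrors the base-2 logsumexp bookkeeping, not the generated kernel itself.
import torch

RCP_LN2 = 1.44269504          # 1 / ln(2), as used in the attention template above
LN2 = 0.6931471805599453      # the constant applied by triton_poi_fused_mul_1

scores = torch.randn(4, 16)   # assumed toy attention scores
m = (scores * RCP_LN2).amax(dim=-1, keepdim=True)      # running max in the log2 domain
l = torch.exp2(scores * RCP_LN2 - m).sum(dim=-1)        # exp2-based accumulator
lse_base2 = m.squeeze(-1) + torch.log2(l)               # what the attention kernel stores
assert torch.allclose(lse_base2 * LN2, torch.logsumexp(scores, dim=-1), atol=1e-5)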
+
+
+async_compile.wait(globals())
+del async_compile
+
+class Runner:
+    def __init__(self, partitions):
+        self.partitions = partitions
+
+    def recursively_apply_fns(self, fns):
+        new_callables = []
+        for fn, c in zip(fns, self.partitions):
+            new_callables.append(fn(c))
+        self.partitions = new_callables
+
+    def call(self, args):
+        primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args
+        args.clear()
+        s50 = primals_1
+        s0 = primals_3
+        s43 = primals_5
+        s37 = primals_8
+        s71 = primals_9
+        assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
+        assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
+        assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
+        assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
+        assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
+        assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
+        assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
+        assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
+        assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
+        assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
+        assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
+        with torch.cuda._DeviceGuard(3):
+            torch.cuda.set_device(3)
+            buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
+            buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
+            buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
+            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
+            stream3 = get_raw_stream(3)
+            triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream3)
+            del buf1
+            buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
+            # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul]
+            triton_poi_fused_mul_1_xnumel = 32*s37
+            stream3 = get_raw_stream(3)
+            triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream3)
+        return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, )
+
+runner = Runner(partitions=[])
+call = runner.call
+recursively_apply_fns = runner.recursively_apply_fns
+
+
+def benchmark_compiled_module(times=10, repeat=10):
+    from torch._dynamo.testing import rand_strided
+    from torch._inductor.utils import print_performance
+    primals_1 = 128
+    primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
+    primals_3 = 128
+    primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
+    primals_5 = 128
+    primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16)
+    primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_8 = 128
+    primals_9 = 128
+    primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32)
+    primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32)
+    fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16])
+    return print_performance(fn, times=times, repeat=repeat)
+
+
+if __name__ == "__main__":
+    from torch._inductor.wrapper_benchmark import compiled_module_main
+    compiled_module_main('None', benchmark_compiled_module)
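For reference, the generated wrapper is self-benchmarking: running the file executes compiled_module_main, which calls benchmark_compiled_module above. A minimal sketch of driving it by hand (the file path is a placeholder, and cuda:3, hard-coded in the wrapper, must exist on the machine):

# Sketch of loading and benchmarking this generated wrapper; the filename is hypothetical.
import importlib.util

spec = importlib.util.spec_from_file_location("compiled_module", "output_code.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# Builds rand_strided inputs matching the assert_size_stride shapes and times `call`.
mod.benchmark_compiled_module(times=10, repeat=10)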
progress/SpecForge/cache/compiled_kernels/2x/dde0479cca0d878e6e0800ec13f7c80962354e837542bfc5f11f7b49306d323e.best_config
ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"}
progress/SpecForge/cache/compiled_kernels/3h/6ee97c795357f97e7127237e15db9bd5fb14510b837eeb5094115cfaa1802d32.best_config
ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
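Each *.best_config file above stores the launch configuration that autotuning selected for the matching kernel, plus bookkeeping fields (a hash of the candidate config set, tuning time, and the Triton cache hash). A short sketch of reading one of the entries added in this commit:

# Sketch: inspect an autotuning result from the cache added above.
import json

path = "progress/SpecForge/cache/compiled_kernels/2x/dde0479cca0d878e6e0800ec13f7c80962354e837542bfc5f11f7b49306d323e.best_config"
with open(path) as f:
    cfg = json.load(f)
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])   # 8 2 1
print(cfg["found_by_coordesc"], cfg["time_taken_ms"])        # False 25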
progress/SpecForge/cache/compiled_kernels/3h/c3h3fb5vqykgr7s3powfrnsc5alooplbijdgjizqo3xq5psavrvz.py
ADDED
@@ -0,0 +1,28 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 4096},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x2 = xindex
+    x0 = (xindex % ks0)
+    x1 = triton_helpers.div_floor_integer(xindex, ks0)
+    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
+    tmp1 = 0.6931471805599453
+    tmp2 = tmp0 * tmp1
+    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
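The store index in line 28 uses the branch-free expression (1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)), which is Inductor's integer spelling of max(1, ks0); it matches the Max(1, s37) stride recorded in the lse_scaled graph fragment earlier. A one-line check of the identity:

# Sketch: the guarded-stride expression above equals max(1, ks0) for non-negative ks0.
for ks0 in (0, 1, 2, 7, 128):
    assert (1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)) == max(1, ks0)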
progress/SpecForge/cache/compiled_kernels/3h/c3hro2ygwh2ixqhmbrrdsjq6biaehv6lm5cbeo6yhlo6ssqkwpha.py
ADDED
@@ -0,0 +1,799 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 36 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 37 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 38 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 39 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 41 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 42 |
+
Q = arg_Q
|
| 43 |
+
K = arg_K
|
| 44 |
+
V = arg_V
|
| 45 |
+
LSE = arg_LSE
|
| 46 |
+
DELTA = arg_DELTA
|
| 47 |
+
DO = arg_DO
|
| 48 |
+
DQ = arg_DQ
|
| 49 |
+
DV = arg_DV
|
| 50 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 51 |
+
KV_IDX = arg_KV_IDX
|
| 52 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 53 |
+
Q_IDX = arg_Q_IDX
|
| 54 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 55 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 56 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 57 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 58 |
+
|
| 59 |
+
# Sub notation for this kernel:
|
| 60 |
+
#
|
| 61 |
+
# Q: Query, K: Key, V: Value
|
| 62 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 63 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 64 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 65 |
+
# DK: Derivative of Key, is written to via the store_output call due to some limitations with
|
| 66 |
+
# inductor codegen
|
| 67 |
+
# M: Number of queries, N: Number of keys/values
|
| 68 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 69 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 70 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 71 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 72 |
+
# (Modifiable) Performance tuning options
|
| 73 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 74 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 75 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 76 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 77 |
+
#
|
| 78 |
+
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
| 79 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 80 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 81 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 82 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 83 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 84 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 85 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 86 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 87 |
+
|
| 88 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 89 |
+
# or involve a numerics vs. perf tradeoff
|
| 90 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 91 |
+
# about 20% more numerical error, but slightly faster.
|
| 92 |
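# [Editorial sketch, not part of the generated kernel] DELTA above is produced before this
# backward kernel runs, as the row-wise sum of the forward output with its incoming gradient.
# In plain PyTorch terms (toy shapes assumed for illustration):
#     out = torch.randn(1, 32, 128, 128)      # forward attention output O
#     do = torch.randn_like(out)              # upstream gradient dO
#     delta = (out * do).sum(dim=-1).float()  # DELTA = sum(OUT*DO, axis=-1), kept in fp32
# Rows of delta are what the blocks below load as Di when forming ds = p * (dp - Di).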
+
|
| 93 |
+
# Define strides of inputs
|
| 94 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
|
| 95 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
|
| 96 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
|
| 97 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
|
| 98 |
+
|
| 99 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
|
| 100 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
|
| 101 |
+
|
| 102 |
+
ZQ = 1
|
| 103 |
+
HQ = 32
|
| 104 |
+
HKV = 8
|
| 105 |
+
Q_LEN = ks0
|
| 106 |
+
ZKV = 1
|
| 107 |
+
KV_LEN = ks1
|
| 108 |
+
|
| 109 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 110 |
+
|
| 111 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 112 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 113 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 114 |
+
|
| 115 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 116 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 117 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 118 |
+
|
| 119 |
+
SPARSE_Z = 1
|
| 120 |
+
SPARSE_HQ = 1
|
| 121 |
+
|
| 122 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 123 |
+
|
| 124 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 125 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 126 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 127 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 128 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 129 |
+
|
| 130 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 131 |
+
K += k_adj
|
| 132 |
+
V += v_adj
|
| 133 |
+
DV += dv_adj
|
| 134 |
+
|
| 135 |
+
RCP_LN2 = 1.44269504
|
| 136 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 137 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 138 |
+
|
| 139 |
+
if pid >= NUM_KV_BLOCKS:
|
| 140 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 141 |
+
# THIS BLOCK DOES DQ
|
| 142 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 143 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 144 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 145 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 146 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 147 |
+
stride_kv_num_blks_h = 1
|
| 148 |
+
stride_kv_idx_h = 1
|
| 149 |
+
stride_kv_idx_m = 1
|
| 150 |
+
|
| 151 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 152 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 153 |
+
|
| 154 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 155 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 156 |
+
|
| 157 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 158 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 159 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 160 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 161 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 162 |
+
|
| 163 |
+
Q2 = Q + q_adj2
|
| 164 |
+
DO2 = DO + do_adj2
|
| 165 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 166 |
+
# if Q is broadcasted)
|
| 167 |
+
DQ2 = DQ + dq_adj2
|
| 168 |
+
LSE2 = LSE + off_chz2
|
| 169 |
+
DELTA2 = DELTA + off_chz2
|
| 170 |
+
|
| 171 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 172 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 173 |
+
|
| 174 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 175 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 176 |
+
|
| 177 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 178 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 179 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 180 |
+
|
| 181 |
+
if PRESCALE_QK:
|
| 182 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 183 |
+
|
| 184 |
+
if IS_DIVISIBLE:
|
| 185 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 186 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 187 |
+
else:
|
| 188 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 189 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 190 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 191 |
+
lse = lse[:, None]
|
| 192 |
+
|
| 193 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 194 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 195 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 196 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 197 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 198 |
+
|
| 199 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 200 |
+
dq = bwd_dq_inner(
|
| 201 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 202 |
+
K, V,
|
| 203 |
+
dq, q, do, Di, lse,
|
| 204 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 205 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 206 |
+
kv_indices, sparse_kv_num_blocks,
|
| 207 |
+
MATMUL_PRECISION,
|
| 208 |
+
IS_FULL_BLOCKS=False,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if HAS_FULL_BLOCKS:
|
| 212 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 213 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 214 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 215 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 216 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 217 |
+
|
| 218 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 219 |
+
dq = bwd_dq_inner(
|
| 220 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 221 |
+
K, V,
|
| 222 |
+
dq, q, do, Di, lse,
|
| 223 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 224 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 225 |
+
kv_indices, sparse_kv_num_blocks,
|
| 226 |
+
MATMUL_PRECISION,
|
| 227 |
+
IS_FULL_BLOCKS=True,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Write back dQ.
|
| 231 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 232 |
+
dq *= SM_SCALE
|
| 233 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 234 |
+
tl.store(dq_ptrs, dq)
|
| 235 |
+
else:
|
| 236 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 237 |
+
else:
|
| 238 |
+
# THIS BLOCK DOES DK & DV
|
| 239 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 240 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 241 |
+
|
| 242 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 243 |
+
|
| 244 |
+
stride_q_num_blks_h = 1
|
| 245 |
+
stride_q_idx_h = 1
|
| 246 |
+
stride_q_idx_n = 1
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 250 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 251 |
+
|
| 252 |
+
start_n1 = pid * BLOCK_N1
|
| 253 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 254 |
+
|
| 255 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 256 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 257 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 258 |
+
|
| 259 |
+
if PRESCALE_QK:
|
| 260 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 261 |
+
|
| 262 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 263 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 264 |
+
|
| 265 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 266 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 267 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 268 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 269 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 270 |
+
|
| 271 |
+
Q1 = Q + q_adj1
|
| 272 |
+
DO1 = DO + do_adj1
|
| 273 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 274 |
+
# if Q is broadcasted)
|
| 275 |
+
LSE1 = LSE + off_chz1
|
| 276 |
+
DELTA1 = DELTA + off_chz1
|
| 277 |
+
|
| 278 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 279 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 280 |
+
|
| 281 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 282 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 283 |
+
|
| 284 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 285 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 286 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 287 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 288 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 289 |
+
|
| 290 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 291 |
+
dk, dv = bwd_dkdv_inner(
|
| 292 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 293 |
+
Q1, DO1, DELTA1, LSE1,
|
| 294 |
+
dk, dv, k, v,
|
| 295 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 296 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 297 |
+
q_indices, sparse_q_num_blocks,
|
| 298 |
+
MATMUL_PRECISION,
|
| 299 |
+
IS_FULL_BLOCKS=False,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if HAS_FULL_BLOCKS:
|
| 304 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 305 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 306 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 307 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 308 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 309 |
+
|
| 310 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 311 |
+
dk, dv = bwd_dkdv_inner(
|
| 312 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 313 |
+
Q1, DO1, DELTA1, LSE1,
|
| 314 |
+
dk, dv, k, v,
|
| 315 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 316 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 317 |
+
q_indices, sparse_q_num_blocks,
|
| 318 |
+
MATMUL_PRECISION,
|
| 319 |
+
IS_FULL_BLOCKS=True,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Write back dV and dK.
|
| 323 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 324 |
+
|
| 325 |
+
index_n = offs_n1[:, None]
|
| 326 |
+
index_k = offs_k[None, :]
|
| 327 |
+
index_v = offs_v[None, :]
|
| 328 |
+
|
| 329 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 330 |
+
tl.store(dv_ptrs, dv)
|
| 331 |
+
else:
|
| 332 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 333 |
+
|
| 334 |
+
dk *= SM_SCALE
|
| 335 |
+
|
| 336 |
+
if SAFE_HEAD_DIM:
|
| 337 |
+
mask = index_n < KV_LEN
|
| 338 |
+
else:
|
| 339 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 340 |
+
|
| 341 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 342 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 343 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 344 |
+
xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
|
| 345 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
|
| 346 |
+
|
| 347 |
+
@triton.jit
|
| 348 |
+
def bwd_dq_inner(
|
| 349 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 350 |
+
K, V, # pointers
|
| 351 |
+
dq, q, do, Di, lse,
|
| 352 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 353 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 354 |
+
kv_indices, sparse_kv_num_blocks,
|
| 355 |
+
MATMUL_PRECISION,
|
| 356 |
+
IS_FULL_BLOCKS,
|
| 357 |
+
):
|
| 358 |
+
PRESCALE_QK : tl.constexpr = False
|
| 359 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 360 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 361 |
+
WRITE_DQ : tl.constexpr = True
|
| 362 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 363 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 364 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 365 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 366 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 367 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 368 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 369 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 370 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 371 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 372 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 373 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 374 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 375 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 376 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 377 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 378 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 379 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 380 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 381 |
+
|
| 382 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 383 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 384 |
+
Q_LEN = ks0
|
| 385 |
+
KV_LEN = ks1
|
| 386 |
+
|
| 387 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 388 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 389 |
+
|
| 390 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 391 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 392 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 393 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 394 |
+
|
| 395 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 396 |
+
|
| 397 |
+
for start_n in range(0, hi):
|
| 398 |
+
dq = bwd_dq_block_mn(
|
| 399 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 400 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 401 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 402 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 403 |
+
kv_indices, sparse_kv_num_blocks,
|
| 404 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 405 |
+
IS_FULL_BLOCKS,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Increment pointers.
|
| 409 |
+
offset = get_offset_for_next_block(
|
| 410 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 411 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
kT_ptrs += offset * stride_kn
|
| 415 |
+
vT_ptrs += offset * stride_vn
|
| 416 |
+
|
| 417 |
+
offs_n2 += offset
|
| 418 |
+
|
| 419 |
+
return dq
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
@triton.jit
|
| 423 |
+
def bwd_dq_block_mn(
|
| 424 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 425 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 426 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 427 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 428 |
+
kv_indices, sparse_kv_num_blocks,
|
| 429 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 430 |
+
IS_FULL_BLOCKS,
|
| 431 |
+
):
|
| 432 |
+
PRESCALE_QK : tl.constexpr = False
|
| 433 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 434 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 435 |
+
WRITE_DQ : tl.constexpr = True
|
| 436 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 437 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 438 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 439 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 440 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 441 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 442 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 443 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 444 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 445 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 446 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 447 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 448 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 449 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 450 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 451 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 452 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 453 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 454 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# NB reversed order since K is transposed
|
| 458 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 459 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 460 |
+
if not PRESCALE_QK:
|
| 461 |
+
qk *= SM_SCALE
|
| 462 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 463 |
+
pre_mod_scores = qk
|
| 464 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 465 |
+
# The boundary check is done for the outer loop, but here it's possible, since we're iterating across the N dim,
|
| 466 |
+
# that the M reads go out of bounds for the PIDs spanning the Q_LEN boundary
|
| 467 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 468 |
+
|
| 469 |
+
tmp0 = (qk)
|
| 470 |
+
post_mod_scores = tmp0
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if not IS_DIVISIBLE:
|
| 476 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 477 |
+
|
| 478 |
+
if not IS_FULL_BLOCKS:
|
| 479 |
+
tmp1 = (m)
|
| 480 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 481 |
+
tmp3 = tmp1 < tmp2
|
| 482 |
+
tmp4 = (n)
|
| 483 |
+
tmp5 = tmp4 <= tmp1
|
| 484 |
+
tmp6 = tmp3 & tmp5
|
| 485 |
+
tmp7 = tmp1 >= tmp2
|
| 486 |
+
tmp8 = tmp4 < tmp2
|
| 487 |
+
tmp9 = tmp7 & tmp8
|
| 488 |
+
tmp10 = tmp8 == 0
|
| 489 |
+
tmp11 = tmp7 & tmp10
|
| 490 |
+
tmp12 = tmp1 - tmp2
|
| 491 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 492 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 493 |
+
tmp15 = tmp4 - tmp2
|
| 494 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 495 |
+
tmp17 = tmp14 == tmp16
|
| 496 |
+
tmp18 = tmp11 & tmp17
|
| 497 |
+
tmp19 = tmp9 | tmp18
|
| 498 |
+
tmp20 = tmp6 | tmp19
|
| 499 |
+
mask_mod_output = tmp20
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# apply mask for partial masked block
|
| 503 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 504 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 505 |
+
if not PRESCALE_QK:
|
| 506 |
+
post_mod_scores *= RCP_LN2
|
| 507 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 508 |
+
# Compute dP and dS.
|
| 509 |
+
# NB reversed order since V is transposed
|
| 510 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 511 |
+
|
| 512 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 513 |
+
ds = p * (dp - Di[:, None])
|
| 514 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 515 |
+
tmp21 = (ds)
|
| 516 |
+
grad_scores = tmp21
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
if not IS_DIVISIBLE:
|
| 520 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 521 |
+
|
| 522 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 523 |
+
if WRITE_DQ:
|
| 524 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 525 |
+
|
| 526 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 527 |
+
ds = grad_scores
|
| 528 |
+
|
| 529 |
+
if not IS_FULL_BLOCKS:
|
| 530 |
+
# (grads) apply mask for partially unmasked block
|
| 531 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 532 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 533 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 534 |
+
# Compute dQ.
|
| 535 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 536 |
+
|
| 537 |
+
return dq
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
@triton.jit
|
| 541 |
+
def bwd_dkdv_inner(
|
| 542 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 543 |
+
Q, DO, DELTA, LSE, # pointers
|
| 544 |
+
dk, dv, k, v,
|
| 545 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 546 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 547 |
+
q_indices, sparse_q_num_blocks,
|
| 548 |
+
MATMUL_PRECISION,
|
| 549 |
+
IS_FULL_BLOCKS,
|
| 550 |
+
):
|
| 551 |
+
PRESCALE_QK : tl.constexpr = False
|
| 552 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 553 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 554 |
+
WRITE_DQ : tl.constexpr = True
|
| 555 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 556 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 557 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 558 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 561 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 562 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 563 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 567 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 568 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 569 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 570 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 571 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 572 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 573 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 574 |
+
|
| 575 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
Q_LEN = ks0
|
| 578 |
+
KV_LEN = ks1
|
| 579 |
+
|
| 580 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 581 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 582 |
+
|
| 583 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 584 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 585 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 586 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 587 |
+
|
| 588 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 589 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 590 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 591 |
+
|
| 592 |
+
for start_m in range(0, hi):
|
| 593 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 594 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 595 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 596 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 597 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 598 |
+
q_indices, sparse_q_num_blocks,
|
| 599 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
# Increment pointers.
|
| 603 |
+
offset = get_offset_for_next_block(
|
| 604 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 605 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
qT_ptrs += offset * stride_qm
|
| 609 |
+
do_ptrs += offset * stride_dom
|
| 610 |
+
offs_m1 += offset
|
| 611 |
+
|
| 612 |
+
return dk, dv
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
@triton.jit
|
| 616 |
+
def bwd_dkdv_block_mn(
|
| 617 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 618 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 619 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 620 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 621 |
+
q_indices, sparse_q_num_blocks,
|
| 622 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 623 |
+
IS_FULL_BLOCKS,
|
| 624 |
+
):
|
| 625 |
+
PRESCALE_QK : tl.constexpr = False
|
| 626 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 627 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 628 |
+
WRITE_DQ : tl.constexpr = True
|
| 629 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 630 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 631 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 632 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 633 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 634 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 635 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 636 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 637 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 638 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 639 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 640 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 641 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 642 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 643 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 644 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 645 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 646 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 647 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# NB reversed order since Q is transposed
|
| 651 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 652 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 653 |
+
if IS_DIVISIBLE:
|
| 654 |
+
lse = tl.load(LSE + offs_m1)
|
| 655 |
+
else:
|
| 656 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 657 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 658 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 659 |
+
if not PRESCALE_QK:
|
| 660 |
+
qkT *= SM_SCALE
|
| 661 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 662 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 663 |
+
# The boundary check is done for the outer loop, but here it's possible, since we're iterating across the M dim,
|
| 664 |
+
# that the n reads go out of bounds for the PIDs spanning the KV_LEN boundary
|
| 665 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 666 |
+
|
| 667 |
+
pre_mod_scores = qkT
|
| 668 |
+
tmp22 = (qkT)
|
| 669 |
+
post_mod_scores = tmp22
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
if not IS_DIVISIBLE:
|
| 674 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 675 |
+
|
| 676 |
+
if not IS_FULL_BLOCKS:
|
| 677 |
+
tmp23 = (m)
|
| 678 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 679 |
+
tmp25 = tmp23 < tmp24
|
| 680 |
+
tmp26 = (n)
|
| 681 |
+
tmp27 = tmp26 <= tmp23
|
| 682 |
+
tmp28 = tmp25 & tmp27
|
| 683 |
+
tmp29 = tmp23 >= tmp24
|
| 684 |
+
tmp30 = tmp26 < tmp24
|
| 685 |
+
tmp31 = tmp29 & tmp30
|
| 686 |
+
tmp32 = tmp30 == 0
|
| 687 |
+
tmp33 = tmp29 & tmp32
|
| 688 |
+
tmp34 = tmp23 - tmp24
|
| 689 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 690 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 691 |
+
tmp37 = tmp26 - tmp24
|
| 692 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 693 |
+
tmp39 = tmp36 == tmp38
|
| 694 |
+
tmp40 = tmp33 & tmp39
|
| 695 |
+
tmp41 = tmp31 | tmp40
|
| 696 |
+
tmp42 = tmp28 | tmp41
|
| 697 |
+
mask_mod_output = tmp42
|
| 698 |
+
|
| 699 |
+
# (grads) apply mask for fully masked block
|
| 700 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 701 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 702 |
+
if not PRESCALE_QK:
|
| 703 |
+
post_mod_scores *= RCP_LN2
|
| 704 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 705 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 706 |
+
# Compute dV.
|
| 707 |
+
ppT = pT
|
| 708 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 709 |
+
if IS_DIVISIBLE:
|
| 710 |
+
Di = tl.load(DELTA + offs_m1)
|
| 711 |
+
else:
|
| 712 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 713 |
+
# Compute dP and dS.
|
| 714 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 715 |
+
dsT = pT * (dpT - Di[None, :])
|
| 716 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 717 |
+
tmp43 = (dsT)
|
| 718 |
+
grad_scores = tmp43
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if not IS_DIVISIBLE:
|
| 723 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 724 |
+
|
| 725 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 726 |
+
if not WRITE_DQ:
|
| 727 |
+
idx_b = off_z
|
| 728 |
+
idx_h = off_hq
|
| 729 |
+
idx_m = m
|
| 730 |
+
idx_n = n
|
| 731 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 732 |
+
|
| 733 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 734 |
+
dsT = grad_scores
|
| 735 |
+
if not IS_FULL_BLOCKS:
|
| 736 |
+
# (grads) apply mask for partially unmasked block
|
| 737 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 738 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 739 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 740 |
+
|
| 741 |
+
return dk, dv
|
| 742 |
+
|
| 743 |
+
# Utility triton funcs
|
| 744 |
+
@triton.jit
|
| 745 |
+
def get_offset_for_next_block(
|
| 746 |
+
loop_iter, col_indices, total_blocks,
|
| 747 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 748 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 749 |
+
):
|
| 750 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 751 |
+
return BLOCK
|
| 752 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 753 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 754 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 755 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 756 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 757 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 758 |
+
return offset
|
| 759 |
+
|
| 760 |
+
@triton.jit
|
| 761 |
+
def get_bounded_indices(indices, max_len=None):
|
| 762 |
+
return indices % max_len if max_len is not None else indices
|
| 763 |
+
|
| 764 |
+
@triton.jit
|
| 765 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 766 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 767 |
+
return tl.load(block_ptr)
|
| 768 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 769 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 770 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 771 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 772 |
+
else:
|
| 773 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 774 |
+
|
| 775 |
+
@triton.jit
|
| 776 |
+
def load_checked_2d(
|
| 777 |
+
ptr,
|
| 778 |
+
offs_m,
|
| 779 |
+
offs_n,
|
| 780 |
+
stride_m,
|
| 781 |
+
stride_n,
|
| 782 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 783 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 784 |
+
M_LEN: tl.constexpr,
|
| 785 |
+
N_LEN: tl.constexpr,
|
| 786 |
+
):
|
| 787 |
+
# Calculate final pointer if strides are provided
|
| 788 |
+
if stride_m is not None and stride_n is not None:
|
| 789 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 790 |
+
|
| 791 |
+
# Handle all masking cases
|
| 792 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 793 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 794 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 795 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 796 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 797 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 798 |
+
else: # Both divisible
|
| 799 |
+
return tl.load(ptr)
|
progress/SpecForge/cache/compiled_kernels/3p/c3pdrhexk4rwol7f5l5vh7n543dj6piq6gw5k66g2p4vlyhopnop.py
ADDED
@@ -0,0 +1,58 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.persistent_reduction(
+    size_hints={'x': 2048, 'r0_': 32},
+    reduction_hint=ReductionHint.DEFAULT,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
+    r0_numel = 32
+    R0_BLOCK: tl.constexpr = 32
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_index = tl.arange(0, R0_BLOCK)[None, :]
+    r0_offset = 0
+    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+    roffset = r0_offset
+    rindex = r0_index
+    r0_1 = r0_index
+    x0 = xindex
+    x2 = (xindex % ks0)
+    x3 = triton_helpers.div_floor_integer(xindex, ks0)
+    tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
+    tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
+    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+    tmp3 = tl.where(xmask, tmp1, float("-inf"))
+    tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
+    tmp6 = float("-inf")
+    tmp7 = tmp4 == tmp6
+    tmp8 = tmp0 - tmp4
+    tmp9 = 0.0
+    tmp10 = tl.where(tmp7, tmp9, tmp8)
+    tmp11 = libdevice.exp2(tmp10)
+    tmp12 = tmp5 * tmp11
+    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
+    tmp15 = tl.where(xmask, tmp13, 0)
+    tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
+    tmp17 = 1.0
+    tmp18 = tl.where(tmp7, tmp17, tmp16)
+    tmp19 = libdevice.log2(tmp18)
+    tmp20 = tmp19 + tmp4
+    tmp21 = 0.6931471805599453
+    tmp22 = tmp20 * tmp21
+    tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
+    tl.store(out_ptr0 + (x0), tmp4, xmask)
| 58 |
+
tl.store(out_ptr1 + (x0), tmp16, xmask)
|
progress/SpecForge/cache/compiled_kernels/3p/d4e91f4bc49d9cfc59a03caa3a2e04988f99c358762e8d23eed306dbbe3eae25.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 69, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"}
|
progress/SpecForge/cache/compiled_kernels/3s/c3spq2k2yeawxvgwl4dczrad6qwkidiiyxz5xwsucqivwlx625g7.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
USE_TMA : tl.constexpr = False
|
| 36 |
+
BLOCK_M : tl.constexpr = 128
|
| 37 |
+
BLOCK_N : tl.constexpr = 64
|
| 38 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 39 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 41 |
+
Q = arg_Q
|
| 42 |
+
K = arg_K
|
| 43 |
+
V = arg_V
|
| 44 |
+
LSE = arg_LSE
|
| 45 |
+
MAX = arg_MAX
|
| 46 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 47 |
+
KV_IDX = arg_KV_IDX
|
| 48 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 49 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 50 |
+
|
| 51 |
+
# Sub notation for this kernel:
|
| 52 |
+
#
|
| 53 |
+
# Q: Query, K: Key, V: Value
|
| 54 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 55 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 56 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 57 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 58 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 59 |
+
#
|
| 60 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 61 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 62 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 63 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 64 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 65 |
+
#
|
| 66 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 67 |
+
#
|
| 68 |
+
# (Modifiable) Performance tuning options
|
| 69 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 70 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 71 |
+
|
| 72 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 73 |
+
# or involve a numerics vs. perf tradeoff
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 75 |
+
# about 20% more numerical error, but slightly faster.
|
| 76 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 77 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 78 |
+
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
| 79 |
+
# contiguous? If so, we don't need to do an indirect jump for every block
|
| 80 |
+
|
| 81 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 82 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 83 |
+
|
| 84 |
+
# Define strides of inputs
|
| 85 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 86 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 87 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
|
| 88 |
+
|
| 89 |
+
ZQ = 1
|
| 90 |
+
HQ = 32
|
| 91 |
+
Q_LEN = ks0
|
| 92 |
+
ZKV = 1
|
| 93 |
+
KV_LEN = ks1
|
| 94 |
+
|
| 95 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 96 |
+
|
| 97 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 98 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 99 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 100 |
+
|
| 101 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 102 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 103 |
+
off_zkv = off_zq % ZKV
|
| 104 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 105 |
+
off_g = off_hq % GQA_SHARED_HEADS
|
| 106 |
+
|
| 107 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 108 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 109 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 110 |
+
|
| 111 |
+
Q = Q + q_offset
|
| 112 |
+
K = K + k_offset
|
| 113 |
+
V = V + v_offset
|
| 114 |
+
|
| 115 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 116 |
+
desc_q = None
|
| 117 |
+
desc_k = None
|
| 118 |
+
desc_v = None
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 124 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 127 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 128 |
+
|
| 129 |
+
stride_kv_num_blks_h = ks2
|
| 130 |
+
stride_kv_idx_h = ks3*ks4
|
| 131 |
+
stride_kv_idx_m = ks4
|
| 132 |
+
|
| 133 |
+
# initialize pointer to m and l
|
| 134 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 135 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 136 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 137 |
+
|
| 138 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 139 |
+
|
| 140 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 141 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 142 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 143 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 144 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 145 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 146 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 147 |
+
|
| 148 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 149 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 150 |
+
# both score_mod and mask_mod to it
|
| 151 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 152 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 153 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 154 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# K and V pointers will be passed directly to forward_inner
|
| 158 |
+
|
| 159 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
acc, l_i, m_i = forward_inner(
|
| 163 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 164 |
+
q, K, V,
|
| 165 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 166 |
+
acc, l_i, m_i,
|
| 167 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 168 |
+
kv_start,
|
| 169 |
+
kv_indices, kv_num_blocks,
|
| 170 |
+
0, block_n_end,
|
| 171 |
+
MATMUL_PRECISION,
|
| 172 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 173 |
+
IS_FULL_BLOCKS=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 177 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 178 |
+
# apply mask_mod to them - only score_mod
|
| 179 |
+
if HAS_FULL_BLOCKS:
|
| 180 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 181 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 182 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 183 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 184 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 185 |
+
# K and V pointers will be passed directly to forward_inner
|
| 186 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 190 |
+
q, K, V,
|
| 191 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 194 |
+
kv_start,
|
| 195 |
+
kv_indices, kv_num_blocks,
|
| 196 |
+
0, block_n_end,
|
| 197 |
+
MATMUL_PRECISION,
|
| 198 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 199 |
+
IS_FULL_BLOCKS=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# [Note] Handle fully masked out rows:
|
| 204 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 205 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 206 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 207 |
+
|
| 208 |
+
acc = acc / l_i[:, None]
|
| 209 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 210 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 211 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 212 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 213 |
+
|
| 214 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 215 |
+
|
| 216 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 217 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 218 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 219 |
+
|
| 220 |
+
if OUTPUT_LOGSUMEXP:
|
| 221 |
+
off_hz = off_zq * HQ + off_hq
|
| 222 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 223 |
+
lse = m_i + tl.math.log2(l_i)
|
| 224 |
+
if IS_DIVISIBLE:
|
| 225 |
+
tl.store(l_ptrs, lse)
|
| 226 |
+
else:
|
| 227 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 228 |
+
|
| 229 |
+
if OUTPUT_MAX:
|
| 230 |
+
off_hz = off_zq * HQ + off_hq
|
| 231 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 232 |
+
if IS_DIVISIBLE:
|
| 233 |
+
tl.store(max_ptrs, m_i)
|
| 234 |
+
else:
|
| 235 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Utility triton funcs
|
| 239 |
+
@triton.jit
|
| 240 |
+
def get_offset_for_next_block(
|
| 241 |
+
loop_iter, col_indices, total_blocks,
|
| 242 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 243 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 244 |
+
):
|
| 245 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 246 |
+
return BLOCK
|
| 247 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 248 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 249 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 250 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 251 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 252 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 253 |
+
return offset
|
| 254 |
+
|
| 255 |
+
@triton.jit
|
| 256 |
+
def get_bounded_indices(indices, max_len=None):
|
| 257 |
+
return indices % max_len if max_len is not None else indices
|
| 258 |
+
|
| 259 |
+
@triton.jit
|
| 260 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 261 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 262 |
+
return tl.load(block_ptr)
|
| 263 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 264 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 265 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 266 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 267 |
+
else:
|
| 268 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 269 |
+
|
| 270 |
+
@triton.jit
|
| 271 |
+
def load_checked_2d(
|
| 272 |
+
ptr,
|
| 273 |
+
offs_m,
|
| 274 |
+
offs_n,
|
| 275 |
+
stride_m,
|
| 276 |
+
stride_n,
|
| 277 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 278 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 279 |
+
M_LEN: tl.constexpr,
|
| 280 |
+
N_LEN: tl.constexpr,
|
| 281 |
+
):
|
| 282 |
+
# Calculate final pointer if strides are provided
|
| 283 |
+
if stride_m is not None and stride_n is not None:
|
| 284 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 285 |
+
|
| 286 |
+
# Handle all masking cases
|
| 287 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 288 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 289 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 290 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 291 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 292 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 293 |
+
else: # Both divisible
|
| 294 |
+
return tl.load(ptr)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Common Imports
|
| 298 |
+
@triton.jit
|
| 299 |
+
def forward_block_mn(
|
| 300 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 301 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 302 |
+
# accumulated values
|
| 303 |
+
acc, l_i, m_i,
|
| 304 |
+
# Offsets
|
| 305 |
+
off_z, off_h, offs_m, offs_n,
|
| 306 |
+
# Offsets needed for TMA loads
|
| 307 |
+
kv_start,
|
| 308 |
+
kv_offset,
|
| 309 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 310 |
+
# Strides for K and V
|
| 311 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 312 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 313 |
+
|
| 314 |
+
):
|
| 315 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 316 |
+
PRESCALE_QK : tl.constexpr = False
|
| 317 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 318 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 319 |
+
WRITE_DQ : tl.constexpr = True
|
| 320 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 321 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 322 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 323 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 324 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 325 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 326 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 327 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 328 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 329 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 330 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 331 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 332 |
+
USE_TMA : tl.constexpr = False
|
| 333 |
+
BLOCK_M : tl.constexpr = 128
|
| 334 |
+
BLOCK_N : tl.constexpr = 64
|
| 335 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 336 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 337 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# -- load k --
|
| 341 |
+
# NB reversed order to since K is transposed
|
| 342 |
+
kv_base_offset = kv_start + kv_offset
|
| 343 |
+
|
| 344 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 345 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 346 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 347 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 348 |
+
|
| 349 |
+
k = tl.trans(k)
|
| 350 |
+
# -- compute qk ---
|
| 351 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 352 |
+
if not PRESCALE_QK:
|
| 353 |
+
qk *= SM_SCALE
|
| 354 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 355 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 356 |
+
# which is larger than the actual number of elements. To avoid access memory out of bound,
|
| 357 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 358 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 359 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 360 |
+
|
| 361 |
+
tmp0 = (qk)
|
| 362 |
+
post_mod_scores = tmp0
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 366 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 367 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 368 |
+
|
| 369 |
+
if not IS_FULL_BLOCKS:
|
| 370 |
+
tmp1 = (m)
|
| 371 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 372 |
+
tmp3 = tmp1 < tmp2
|
| 373 |
+
tmp4 = (n)
|
| 374 |
+
tmp5 = tmp4 <= tmp1
|
| 375 |
+
tmp6 = tmp3 & tmp5
|
| 376 |
+
tmp7 = tmp1 >= tmp2
|
| 377 |
+
tmp8 = tmp4 < tmp2
|
| 378 |
+
tmp9 = tmp7 & tmp8
|
| 379 |
+
tmp10 = tmp8 == 0
|
| 380 |
+
tmp11 = tmp7 & tmp10
|
| 381 |
+
tmp12 = tmp1 - tmp2
|
| 382 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 383 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 384 |
+
tmp15 = tmp4 - tmp2
|
| 385 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 386 |
+
tmp17 = tmp14 == tmp16
|
| 387 |
+
tmp18 = tmp11 & tmp17
|
| 388 |
+
tmp19 = tmp9 | tmp18
|
| 389 |
+
tmp20 = tmp6 | tmp19
|
| 390 |
+
mask_mod_output = tmp20
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 394 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 395 |
+
# apply mask for partially unmasked blocks
|
| 396 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 397 |
+
|
| 398 |
+
if not PRESCALE_QK:
|
| 399 |
+
post_mod_scores *= RCP_LN2
|
| 400 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 401 |
+
|
| 402 |
+
# -- compute scaling constant ---
|
| 403 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 404 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 405 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 406 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 407 |
+
else:
|
| 408 |
+
m_ij_masked = m_ij
|
| 409 |
+
|
| 410 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 411 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 412 |
+
|
| 413 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 414 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 415 |
+
# m_ij
|
| 416 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 417 |
+
# # -- scale and update acc --
|
| 418 |
+
acc = acc * alpha[:, None]
|
| 419 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 420 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 421 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 422 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 423 |
+
|
| 424 |
+
# -- update m_i
|
| 425 |
+
m_i = m_ij
|
| 426 |
+
|
| 427 |
+
return acc, l_i, m_i
|
| 428 |
+
|
| 429 |
+
@triton.jit
|
| 430 |
+
def forward_inner(
|
| 431 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 432 |
+
q, K, V,
|
| 433 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 434 |
+
# accumulated values
|
| 435 |
+
acc, l_i, m_i,
|
| 436 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 437 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 438 |
+
off_z, off_h, offs_m, offs_n,
|
| 439 |
+
# Offsets needed for TMA loads
|
| 440 |
+
kv_start,
|
| 441 |
+
# blocksparse data
|
| 442 |
+
kv_indices, kv_num_blocks,
|
| 443 |
+
# start kv and end kv block
|
| 444 |
+
block_n_start, block_n_end,
|
| 445 |
+
MATMUL_PRECISION,
|
| 446 |
+
# Strides for K and V
|
| 447 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 448 |
+
IS_FULL_BLOCKS,
|
| 449 |
+
):
|
| 450 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 451 |
+
PRESCALE_QK : tl.constexpr = False
|
| 452 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 453 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 454 |
+
WRITE_DQ : tl.constexpr = True
|
| 455 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 456 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 457 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 458 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 459 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 460 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 461 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 462 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 463 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 464 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 465 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 466 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 467 |
+
USE_TMA : tl.constexpr = False
|
| 468 |
+
BLOCK_M : tl.constexpr = 128
|
| 469 |
+
BLOCK_N : tl.constexpr = 64
|
| 470 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 471 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 472 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 476 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 477 |
+
|
| 478 |
+
if PRESCALE_QK:
|
| 479 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 480 |
+
|
| 481 |
+
kv_offset = 0
|
| 482 |
+
|
| 483 |
+
# loop over k, v and update accumulator until block_n_end
|
| 484 |
+
for start_n in range(block_n_start, block_n_end):
|
| 485 |
+
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 486 |
+
if IS_DIVISIBLE:
|
| 487 |
+
acc, l_i, m_i = forward_block_mn(
|
| 488 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 489 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 490 |
+
# accumulated values
|
| 491 |
+
acc, l_i, m_i,
|
| 492 |
+
# Offsets
|
| 493 |
+
off_z, off_h, offs_m, offs_n,
|
| 494 |
+
# Offsets needed for TMA loads
|
| 495 |
+
kv_start,
|
| 496 |
+
kv_offset,
|
| 497 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 498 |
+
# Strides for K and V
|
| 499 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 500 |
+
IS_FULL_BLOCKS,
|
| 501 |
+
)
|
| 502 |
+
else:
|
| 503 |
+
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
|
| 504 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 505 |
+
# However, we choose different strategy for bwd, where we only apply mod & mask
|
| 506 |
+
# to the last block because it's faster a lot.
|
| 507 |
+
acc, l_i, m_i = forward_block_mn(
|
| 508 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
|
| 509 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 510 |
+
# accumulated values
|
| 511 |
+
acc, l_i, m_i,
|
| 512 |
+
# Offsets
|
| 513 |
+
off_z, off_h, offs_m, offs_n,
|
| 514 |
+
# Offsets needed for TMA loads
|
| 515 |
+
kv_start,
|
| 516 |
+
kv_offset,
|
| 517 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 518 |
+
# Strides for K and V
|
| 519 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 520 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
offset = get_offset_for_next_block(
|
| 526 |
+
start_n, kv_indices, kv_num_blocks,
|
| 527 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
offs_n = offs_n + offset
|
| 531 |
+
kv_offset += offset
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
return acc, l_i, m_i
|
progress/SpecForge/cache/compiled_kernels/3u/c3umapah7vcozhvfk5uovlssor7v533y4crphqgd677nuoizbpvj.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 36 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 37 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 38 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 39 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 41 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 42 |
+
Q = arg_Q
|
| 43 |
+
K = arg_K
|
| 44 |
+
V = arg_V
|
| 45 |
+
LSE = arg_LSE
|
| 46 |
+
DELTA = arg_DELTA
|
| 47 |
+
DO = arg_DO
|
| 48 |
+
DQ = arg_DQ
|
| 49 |
+
DV = arg_DV
|
| 50 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 51 |
+
KV_IDX = arg_KV_IDX
|
| 52 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 53 |
+
Q_IDX = arg_Q_IDX
|
| 54 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 55 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 56 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 57 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 58 |
+
|
| 59 |
+
# Sub notation for this kernel:
|
| 60 |
+
#
|
| 61 |
+
# Q: Query, K: Key, V: Value
|
| 62 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 63 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 64 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 65 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 66 |
+
# inductor codegen
|
| 67 |
+
# M: Number of queries, N: Number of keys/values
|
| 68 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 69 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 70 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 71 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 72 |
+
# (Modifiable) Performance tuning options
|
| 73 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 74 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 75 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 76 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 77 |
+
#
|
| 78 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 79 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 80 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 81 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 82 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 83 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 84 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 85 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 86 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 87 |
+
|
| 88 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 89 |
+
# or involve a numerics vs. perf tradeoff
|
| 90 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 91 |
+
# about 20% more numerical error, but slightly faster.
|
| 92 |
+
|
| 93 |
+
# Define strides of inputs
|
| 94 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
|
| 95 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
|
| 96 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
|
| 97 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
|
| 98 |
+
|
| 99 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
|
| 100 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
|
| 101 |
+
|
| 102 |
+
ZQ = 1
|
| 103 |
+
HQ = 32
|
| 104 |
+
HKV = 8
|
| 105 |
+
Q_LEN = ks0
|
| 106 |
+
ZKV = 1
|
| 107 |
+
KV_LEN = ks1
|
| 108 |
+
|
| 109 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 110 |
+
|
| 111 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 112 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 113 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 114 |
+
|
| 115 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 116 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 117 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 118 |
+
|
| 119 |
+
SPARSE_Z = 1
|
| 120 |
+
SPARSE_HQ = 1
|
| 121 |
+
|
| 122 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 123 |
+
|
| 124 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 125 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 126 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 127 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 128 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 129 |
+
|
| 130 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 131 |
+
K += k_adj
|
| 132 |
+
V += v_adj
|
| 133 |
+
DV += dv_adj
|
| 134 |
+
|
| 135 |
+
RCP_LN2 = 1.44269504
|
| 136 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 137 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 138 |
+
|
| 139 |
+
if pid >= NUM_KV_BLOCKS:
|
| 140 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 141 |
+
# THIS BLOCK DOES DQ
|
| 142 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 143 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 144 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 145 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 146 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 147 |
+
stride_kv_num_blks_h = 1
|
| 148 |
+
stride_kv_idx_h = 1
|
| 149 |
+
stride_kv_idx_m = 1
|
| 150 |
+
|
| 151 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 152 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 153 |
+
|
| 154 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 155 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 156 |
+
|
| 157 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 158 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 159 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 160 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 161 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 162 |
+
|
| 163 |
+
Q2 = Q + q_adj2
|
| 164 |
+
DO2 = DO + do_adj2
|
| 165 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 166 |
+
# if Q is broadcasted)
|
| 167 |
+
DQ2 = DQ + dq_adj2
|
| 168 |
+
LSE2 = LSE + off_chz2
|
| 169 |
+
DELTA2 = DELTA + off_chz2
|
| 170 |
+
|
| 171 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 172 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 173 |
+
|
| 174 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 175 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 176 |
+
|
| 177 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 178 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 179 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 180 |
+
|
| 181 |
+
if PRESCALE_QK:
|
| 182 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 183 |
+
|
| 184 |
+
if IS_DIVISIBLE:
|
| 185 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 186 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 187 |
+
else:
|
| 188 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 189 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 190 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 191 |
+
lse = lse[:, None]
|
| 192 |
+
|
| 193 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 194 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 195 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 196 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 197 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 198 |
+
|
| 199 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 200 |
+
dq = bwd_dq_inner(
|
| 201 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 202 |
+
K, V,
|
| 203 |
+
dq, q, do, Di, lse,
|
| 204 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 205 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 206 |
+
kv_indices, sparse_kv_num_blocks,
|
| 207 |
+
MATMUL_PRECISION,
|
| 208 |
+
IS_FULL_BLOCKS=False,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if HAS_FULL_BLOCKS:
|
| 212 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 213 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 214 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 215 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 216 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 217 |
+
|
| 218 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 219 |
+
dq = bwd_dq_inner(
|
| 220 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 221 |
+
K, V,
|
| 222 |
+
dq, q, do, Di, lse,
|
| 223 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 224 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 225 |
+
kv_indices, sparse_kv_num_blocks,
|
| 226 |
+
MATMUL_PRECISION,
|
| 227 |
+
IS_FULL_BLOCKS=True,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Write back dQ.
|
| 231 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 232 |
+
dq *= SM_SCALE
|
| 233 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 234 |
+
tl.store(dq_ptrs, dq)
|
| 235 |
+
else:
|
| 236 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 237 |
+
else:
|
| 238 |
+
# THIS BLOCK DOES DK & DV
|
| 239 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 240 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 241 |
+
|
| 242 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 243 |
+
|
| 244 |
+
stride_q_num_blks_h = 1
|
| 245 |
+
stride_q_idx_h = 1
|
| 246 |
+
stride_q_idx_n = 1
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 250 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 251 |
+
|
| 252 |
+
start_n1 = pid * BLOCK_N1
|
| 253 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 254 |
+
|
| 255 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 256 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 257 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 258 |
+
|
| 259 |
+
if PRESCALE_QK:
|
| 260 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 261 |
+
|
| 262 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 263 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 264 |
+
|
| 265 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 266 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 267 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 268 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 269 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 270 |
+
|
| 271 |
+
Q1 = Q + q_adj1
|
| 272 |
+
DO1 = DO + do_adj1
|
| 273 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 274 |
+
# if Q is broadcasted)
|
| 275 |
+
LSE1 = LSE + off_chz1
|
| 276 |
+
DELTA1 = DELTA + off_chz1
|
| 277 |
+
|
| 278 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 279 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 280 |
+
|
| 281 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 282 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 283 |
+
|
| 284 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 285 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 286 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 287 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 288 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 289 |
+
|
| 290 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 291 |
+
dk, dv = bwd_dkdv_inner(
|
| 292 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 293 |
+
Q1, DO1, DELTA1, LSE1,
|
| 294 |
+
dk, dv, k, v,
|
| 295 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 296 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 297 |
+
q_indices, sparse_q_num_blocks,
|
| 298 |
+
MATMUL_PRECISION,
|
| 299 |
+
IS_FULL_BLOCKS=False,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if HAS_FULL_BLOCKS:
|
| 304 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 305 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 306 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 307 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 308 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 309 |
+
|
| 310 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 311 |
+
dk, dv = bwd_dkdv_inner(
|
| 312 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 313 |
+
Q1, DO1, DELTA1, LSE1,
|
| 314 |
+
dk, dv, k, v,
|
| 315 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 316 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 317 |
+
q_indices, sparse_q_num_blocks,
|
| 318 |
+
MATMUL_PRECISION,
|
| 319 |
+
IS_FULL_BLOCKS=True,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Write back dV and dK.
|
| 323 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 324 |
+
|
| 325 |
+
index_n = offs_n1[:, None]
|
| 326 |
+
index_k = offs_k[None, :]
|
| 327 |
+
index_v = offs_v[None, :]
|
| 328 |
+
|
| 329 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 330 |
+
tl.store(dv_ptrs, dv)
|
| 331 |
+
else:
|
| 332 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 333 |
+
|
| 334 |
+
dk *= SM_SCALE
|
| 335 |
+
|
| 336 |
+
if SAFE_HEAD_DIM:
|
| 337 |
+
mask = index_n < KV_LEN
|
| 338 |
+
else:
|
| 339 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 340 |
+
|
| 341 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
|
| 342 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
|
| 343 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 344 |
+
xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
|
| 345 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
|
| 346 |
+
|
| 347 |
+
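# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Why the dK/dV branch above loops over GQA_SHARED_HEADS: with grouped-query
# attention several query heads share one kv head, so the dK/dV tile for this kv
# block is the sum of the contributions from each of those query heads. Shapes
# below are hypothetical; masking, scaling and the base-2 bookkeeping are omitted.
import numpy as np

G, M, N, D = 4, 64, 128, 32            # GQA group size, q rows, kv rows, head dim
ds = np.random.randn(G, M, N)          # per-query-head dS for this kv block
p  = np.random.randn(G, M, N)          # recomputed probabilities per query head
q  = np.random.randn(G, M, D)
do = np.random.randn(G, M, D)

dk = np.zeros((N, D))
dv = np.zeros((N, D))
for g in range(G):                     # mirrors `for off_g in range(0, GQA_SHARED_HEADS)`
    dk += ds[g].T @ q[g]
    dv += p[g].T @ do[g]
# --------------------------------------------------------------------------------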
@triton.jit
|
| 348 |
+
def bwd_dq_inner(
|
| 349 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 350 |
+
K, V, # pointers
|
| 351 |
+
dq, q, do, Di, lse,
|
| 352 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 353 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 354 |
+
kv_indices, sparse_kv_num_blocks,
|
| 355 |
+
MATMUL_PRECISION,
|
| 356 |
+
IS_FULL_BLOCKS,
|
| 357 |
+
):
|
| 358 |
+
PRESCALE_QK : tl.constexpr = False
|
| 359 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 360 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 361 |
+
WRITE_DQ : tl.constexpr = True
|
| 362 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 363 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 364 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 365 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 366 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 367 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 368 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 369 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 370 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 371 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 372 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 373 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 374 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 375 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 376 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 377 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 378 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 379 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 380 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 381 |
+
|
| 382 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 383 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 384 |
+
Q_LEN = ks0
|
| 385 |
+
KV_LEN = ks1
|
| 386 |
+
|
| 387 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 388 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 389 |
+
|
| 390 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 391 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 392 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 393 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 394 |
+
|
| 395 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 396 |
+
|
| 397 |
+
for start_n in range(0, hi):
|
| 398 |
+
dq = bwd_dq_block_mn(
|
| 399 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 400 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 401 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 402 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 403 |
+
kv_indices, sparse_kv_num_blocks,
|
| 404 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 405 |
+
IS_FULL_BLOCKS,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Increment pointers.
|
| 409 |
+
offset = get_offset_for_next_block(
|
| 410 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 411 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
kT_ptrs += offset * stride_kn
|
| 415 |
+
vT_ptrs += offset * stride_vn
|
| 416 |
+
|
| 417 |
+
offs_n2 += offset
|
| 418 |
+
|
| 419 |
+
return dq
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
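# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Worked example for the loop bound `hi` computed in bwd_dq_inner/bwd_dkdv_inner:
# with SPARSE_KV_BLOCK_SIZE=128 and BLOCK_N2=64, SPARSE_KV_MULTIPLE is 2; if the
# block mask reports sparse_kv_num_blocks=3 but KV_LEN is only 200, the clamp to
# cdiv(KV_LEN, BLOCK_N2)=4 keeps the loop from running past the real sequence.
def loop_bound(sparse_kv_num_blocks, kv_len, sparse_kv_multiple=2, block_n2=64):
    cdiv = -(-kv_len // block_n2)
    return min(sparse_kv_num_blocks * sparse_kv_multiple, max(cdiv, 1))

assert loop_bound(3, 200) == 4
assert loop_bound(1, 200) == 2
# --------------------------------------------------------------------------------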
@triton.jit
|
| 423 |
+
def bwd_dq_block_mn(
|
| 424 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 425 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 426 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 427 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 428 |
+
kv_indices, sparse_kv_num_blocks,
|
| 429 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 430 |
+
IS_FULL_BLOCKS,
|
| 431 |
+
):
|
| 432 |
+
PRESCALE_QK : tl.constexpr = False
|
| 433 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 434 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 435 |
+
WRITE_DQ : tl.constexpr = True
|
| 436 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 437 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 438 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 439 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 440 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 441 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 442 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 443 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 444 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 445 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 446 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 447 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 448 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 449 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 450 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 451 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 452 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 453 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 454 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# NB reversed order since K is transposed
|
| 458 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 459 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 460 |
+
if not PRESCALE_QK:
|
| 461 |
+
qk *= SM_SCALE
|
| 462 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 463 |
+
pre_mod_scores = qk
|
| 464 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 465 |
+
# The boundary check is done for the outer loop, but since we're iterating across the N dim here,
|
| 466 |
+
# the M reads can go out of bounds for the PIDs spanning the Q_LEN boundary.
|
| 467 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 468 |
+
|
| 469 |
+
tmp0 = (qk)
|
| 470 |
+
post_mod_scores = tmp0
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if not IS_DIVISIBLE:
|
| 476 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 477 |
+
|
| 478 |
+
if not IS_FULL_BLOCKS:
|
| 479 |
+
tmp1 = (m)
|
| 480 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 481 |
+
tmp3 = tmp1 < tmp2
|
| 482 |
+
tmp4 = (n)
|
| 483 |
+
tmp5 = tmp4 <= tmp1
|
| 484 |
+
tmp6 = tmp3 & tmp5
|
| 485 |
+
tmp7 = tmp1 >= tmp2
|
| 486 |
+
tmp8 = tmp4 < tmp2
|
| 487 |
+
tmp9 = tmp7 & tmp8
|
| 488 |
+
tmp10 = tmp8 == 0
|
| 489 |
+
tmp11 = tmp7 & tmp10
|
| 490 |
+
tmp12 = tmp1 - tmp2
|
| 491 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 492 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 493 |
+
tmp15 = tmp4 - tmp2
|
| 494 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 495 |
+
tmp17 = tmp14 == tmp16
|
| 496 |
+
tmp18 = tmp11 & tmp17
|
| 497 |
+
tmp19 = tmp9 | tmp18
|
| 498 |
+
tmp20 = tmp6 | tmp19
|
| 499 |
+
mask_mod_output = tmp20
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# apply mask for partial masked block
|
| 503 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 504 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 505 |
+
if not PRESCALE_QK:
|
| 506 |
+
post_mod_scores *= RCP_LN2
|
| 507 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 508 |
+
# Compute dP and dS.
|
| 509 |
+
# NB reversed order since V is transposed
|
| 510 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 511 |
+
|
| 512 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 513 |
+
ds = p * (dp - Di[:, None])
|
| 514 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 515 |
+
tmp21 = (ds)
|
| 516 |
+
grad_scores = tmp21
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
if not IS_DIVISIBLE:
|
| 520 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 521 |
+
|
| 522 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 523 |
+
if WRITE_DQ:
|
| 524 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 525 |
+
|
| 526 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 527 |
+
ds = grad_scores
|
| 528 |
+
|
| 529 |
+
if not IS_FULL_BLOCKS:
|
| 530 |
+
# (grads) apply mask for partially unmasked block
|
| 531 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 532 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 533 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 534 |
+
# Compute dQ.
|
| 535 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 536 |
+
|
| 537 |
+
return dq
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
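# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Readable restatement of the boolean graph (tmp1..tmp20 above, and tmp23..tmp42 in
# the dK/dV helper below) that produces mask_mod_output. Names are mine and this is
# only my reading of the generated code:
def mask_mod_ref(m, n):
    same_16_block = (m // 16) == (n // 16)      # floor division, as in the kernel
    return (m < 0 and n <= m) or (m >= 0 and (n < 0 or same_16_block))
# --------------------------------------------------------------------------------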
@triton.jit
|
| 541 |
+
def bwd_dkdv_inner(
|
| 542 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 543 |
+
Q, DO, DELTA, LSE, # pointers
|
| 544 |
+
dk, dv, k, v,
|
| 545 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 546 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 547 |
+
q_indices, sparse_q_num_blocks,
|
| 548 |
+
MATMUL_PRECISION,
|
| 549 |
+
IS_FULL_BLOCKS,
|
| 550 |
+
):
|
| 551 |
+
PRESCALE_QK : tl.constexpr = False
|
| 552 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 553 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 554 |
+
WRITE_DQ : tl.constexpr = True
|
| 555 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 556 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 557 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 558 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 561 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 562 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 563 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 567 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 568 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 569 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 570 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 571 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 572 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 573 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 574 |
+
|
| 575 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
Q_LEN = ks0
|
| 578 |
+
KV_LEN = ks1
|
| 579 |
+
|
| 580 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 581 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 582 |
+
|
| 583 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 584 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 585 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 586 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 587 |
+
|
| 588 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 589 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 590 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 591 |
+
|
| 592 |
+
for start_m in range(0, hi):
|
| 593 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 594 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 595 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 596 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 597 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 598 |
+
q_indices, sparse_q_num_blocks,
|
| 599 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
# Increment pointers.
|
| 603 |
+
offset = get_offset_for_next_block(
|
| 604 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 605 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
qT_ptrs += offset * stride_qm
|
| 609 |
+
do_ptrs += offset * stride_dom
|
| 610 |
+
offs_m1 += offset
|
| 611 |
+
|
| 612 |
+
return dk, dv
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
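# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Plain-NumPy reference for one (q-block, kv-block) pair of the dK/dV and dQ math
# implemented above, with hypothetical sizes and without masking or the exp2/base-2
# bookkeeping. Di matches the "Precomputed sum(OUT*DO, axis=-1)" rows the kernel loads.
import numpy as np

M, N, D, scale = 64, 128, 32, 0.125
q, do = np.random.randn(M, D), np.random.randn(M, D)
k, v  = np.random.randn(N, D), np.random.randn(N, D)

s   = q @ k.T * scale
lse = np.log(np.exp(s).sum(axis=-1, keepdims=True))
p   = np.exp(s - lse)                  # recomputed softmax probabilities
out = p @ v
Di  = (out * do).sum(axis=-1, keepdims=True)

dv = p.T @ do                          # dv += tl.dot(ppT, do)
dp = do @ v.T                          # dp / dpT
ds = p * (dp - Di)                     # ds = p * (dp - Di)
dk = ds.T @ q * scale                  # dk += tl.dot(dsT, qT); SM_SCALE applied at write-back
dq = ds @ k * scale                    # dq += tl.dot(ds, kT);  SM_SCALE applied at write-back
# --------------------------------------------------------------------------------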
@triton.jit
|
| 616 |
+
def bwd_dkdv_block_mn(
|
| 617 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 618 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 619 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 620 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 621 |
+
q_indices, sparse_q_num_blocks,
|
| 622 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 623 |
+
IS_FULL_BLOCKS,
|
| 624 |
+
):
|
| 625 |
+
PRESCALE_QK : tl.constexpr = False
|
| 626 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 627 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 628 |
+
WRITE_DQ : tl.constexpr = True
|
| 629 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 630 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 631 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 632 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 633 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 634 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 635 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 636 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 637 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 638 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 639 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 640 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 641 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 642 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 643 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 644 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 645 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 646 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 647 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# NB reversed order since Q is transposed
|
| 651 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 652 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 653 |
+
if IS_DIVISIBLE:
|
| 654 |
+
lse = tl.load(LSE + offs_m1)
|
| 655 |
+
else:
|
| 656 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 657 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 658 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 659 |
+
if not PRESCALE_QK:
|
| 660 |
+
qkT *= SM_SCALE
|
| 661 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 662 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 663 |
+
# The boundary check is done for the outer loop, but since we're iterating across the M dim here,
|
| 664 |
+
# the n reads can go out of bounds for the PIDs spanning the KV_LEN boundary.
|
| 665 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 666 |
+
|
| 667 |
+
pre_mod_scores = qkT
|
| 668 |
+
tmp22 = (qkT)
|
| 669 |
+
post_mod_scores = tmp22
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
if not IS_DIVISIBLE:
|
| 674 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 675 |
+
|
| 676 |
+
if not IS_FULL_BLOCKS:
|
| 677 |
+
tmp23 = (m)
|
| 678 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 679 |
+
tmp25 = tmp23 < tmp24
|
| 680 |
+
tmp26 = (n)
|
| 681 |
+
tmp27 = tmp26 <= tmp23
|
| 682 |
+
tmp28 = tmp25 & tmp27
|
| 683 |
+
tmp29 = tmp23 >= tmp24
|
| 684 |
+
tmp30 = tmp26 < tmp24
|
| 685 |
+
tmp31 = tmp29 & tmp30
|
| 686 |
+
tmp32 = tmp30 == 0
|
| 687 |
+
tmp33 = tmp29 & tmp32
|
| 688 |
+
tmp34 = tmp23 - tmp24
|
| 689 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 690 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 691 |
+
tmp37 = tmp26 - tmp24
|
| 692 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 693 |
+
tmp39 = tmp36 == tmp38
|
| 694 |
+
tmp40 = tmp33 & tmp39
|
| 695 |
+
tmp41 = tmp31 | tmp40
|
| 696 |
+
tmp42 = tmp28 | tmp41
|
| 697 |
+
mask_mod_output = tmp42
|
| 698 |
+
|
| 699 |
+
# (grads) apply mask for fully masked block
|
| 700 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 701 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 702 |
+
if not PRESCALE_QK:
|
| 703 |
+
post_mod_scores *= RCP_LN2
|
| 704 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 705 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 706 |
+
# Compute dV.
|
| 707 |
+
ppT = pT
|
| 708 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 709 |
+
if IS_DIVISIBLE:
|
| 710 |
+
Di = tl.load(DELTA + offs_m1)
|
| 711 |
+
else:
|
| 712 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 713 |
+
# Compute dP and dS.
|
| 714 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 715 |
+
dsT = pT * (dpT - Di[None, :])
|
| 716 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 717 |
+
tmp43 = (dsT)
|
| 718 |
+
grad_scores = tmp43
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if not IS_DIVISIBLE:
|
| 723 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 724 |
+
|
| 725 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 726 |
+
if not WRITE_DQ:
|
| 727 |
+
idx_b = off_z
|
| 728 |
+
idx_h = off_hq
|
| 729 |
+
idx_m = m
|
| 730 |
+
idx_n = n
|
| 731 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 732 |
+
|
| 733 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 734 |
+
dsT = grad_scores
|
| 735 |
+
if not IS_FULL_BLOCKS:
|
| 736 |
+
# (grads) apply mask for partially unmasked block
|
| 737 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 738 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 739 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 740 |
+
|
| 741 |
+
return dk, dv
|
| 742 |
+
|
| 743 |
+
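# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Why scores are multiplied by RCP_LN2 and recovered with tl.math.exp2 above:
# exp(x) == 2**(x * log2(e)), and RCP_LN2 = 1/ln(2) = log2(e), so keeping scores
# and lse in base-2 units lets the kernel use the cheaper exp2. Quick check:
import math

RCP_LN2 = 1.44269504
score, lse = 0.37, 1.25                        # hypothetical values
p_exp  = math.exp(score - lse)
p_exp2 = 2.0 ** (score * RCP_LN2 - lse * RCP_LN2)
assert abs(p_exp - p_exp2) < 1e-7
# --------------------------------------------------------------------------------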
# Utility triton funcs
|
| 744 |
+
@triton.jit
|
| 745 |
+
def get_offset_for_next_block(
|
| 746 |
+
loop_iter, col_indices, total_blocks,
|
| 747 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 748 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 749 |
+
):
|
| 750 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 751 |
+
return BLOCK
|
| 752 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 753 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 754 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 755 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 756 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 757 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 758 |
+
return offset
|
| 759 |
+
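# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Pure-Python restatement of get_offset_for_next_block above (col_indices as a
# plain list; the Triton version loads the same two entries from global memory):
# when the loop finishes the last BLOCK-sized tile of the current sparse block it
# jumps to wherever the next sparse block starts, otherwise it advances one tile.
def next_offset_ref(loop_iter, col_indices, total_blocks,
                    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = col_indices[cur_idx]
    next_block = col_indices[cur_idx + 1] if cur_idx + 1 < total_blocks else cur_block
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump_to_block if needs_jump else BLOCK
# --------------------------------------------------------------------------------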
|
| 760 |
+
@triton.jit
|
| 761 |
+
def get_bounded_indices(indices, max_len=None):
|
| 762 |
+
return indices % max_len if max_len is not None else indices
|
| 763 |
+
|
| 764 |
+
@triton.jit
|
| 765 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 766 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 767 |
+
return tl.load(block_ptr)
|
| 768 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 769 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 770 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 771 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 772 |
+
else:
|
| 773 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 774 |
+
|
| 775 |
+
@triton.jit
|
| 776 |
+
def load_checked_2d(
|
| 777 |
+
ptr,
|
| 778 |
+
offs_m,
|
| 779 |
+
offs_n,
|
| 780 |
+
stride_m,
|
| 781 |
+
stride_n,
|
| 782 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 783 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 784 |
+
M_LEN: tl.constexpr,
|
| 785 |
+
N_LEN: tl.constexpr,
|
| 786 |
+
):
|
| 787 |
+
# Calculate final pointer if strides are provided
|
| 788 |
+
if stride_m is not None and stride_n is not None:
|
| 789 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 790 |
+
|
| 791 |
+
# Handle all masking cases
|
| 792 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 793 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 794 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 795 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 796 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 797 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 798 |
+
else: # Both divisible
|
| 799 |
+
return tl.load(ptr)
|
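# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Reference semantics of load_checked_2d above in NumPy: out-of-range rows/columns
# are padded with zeros instead of being read. offs_m/offs_n are assumed 1-D
# integer arrays, buf a 2-D array already offset to the right base pointer.
import numpy as np

def load_checked_2d_ref(buf, offs_m, offs_n, M_LEN, N_LEN):
    out = np.zeros((len(offs_m), len(offs_n)), dtype=buf.dtype)
    ok_m, ok_n = offs_m < M_LEN, offs_n < N_LEN
    out[np.ix_(ok_m, ok_n)] = buf[np.ix_(offs_m[ok_m], offs_n[ok_n])]
    return out
# --------------------------------------------------------------------------------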
progress/SpecForge/cache/compiled_kernels/42/c424arzgjg22xrcyl4orsbfthh3vxddttchjdd7yswdd5pdxdhtv.py
ADDED
|
@@ -0,0 +1,1019 @@
| 1 |
+
# AOT ID: ['3_backward']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ns/cnsdsjpw2wa2anwrjd5vp4pkfq4mtbebhfmyzeol2hd4mzwu6qnk.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# Graph fragment:
|
| 41 |
+
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem]
|
| 42 |
+
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
|
| 43 |
+
# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
|
| 44 |
+
# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=tangents_2]
|
| 45 |
+
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 46 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 47 |
+
# return %buf0,%buf1
|
| 48 |
+
triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', '''
|
| 49 |
+
import triton
|
| 50 |
+
import triton.language as tl
|
| 51 |
+
|
| 52 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 53 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 54 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 55 |
+
triton_helpers.set_driver_to_gpu()
|
| 56 |
+
|
| 57 |
+
@triton_heuristics.persistent_reduction(
|
| 58 |
+
size_hints={'x': 4096, 'r0_': 128},
|
| 59 |
+
reduction_hint=ReductionHint.INNER,
|
| 60 |
+
filename=__file__,
|
| 61 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
|
| 62 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 63 |
+
)
|
| 64 |
+
@triton.jit
|
| 65 |
+
def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 66 |
+
r0_numel = 128
|
| 67 |
+
R0_BLOCK: tl.constexpr = 128
|
| 68 |
+
rnumel = r0_numel
|
| 69 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 70 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 71 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 72 |
+
xmask = xindex < xnumel
|
| 73 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 74 |
+
r0_offset = 0
|
| 75 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 76 |
+
roffset = r0_offset
|
| 77 |
+
rindex = r0_index
|
| 78 |
+
r0_2 = r0_index
|
| 79 |
+
x0 = (xindex % ks0)
|
| 80 |
+
x1 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 81 |
+
x3 = xindex
|
| 82 |
+
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
|
| 83 |
+
tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
|
| 84 |
+
tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
|
| 85 |
+
tmp2 = tmp0 * tmp1
|
| 86 |
+
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
|
| 87 |
+
tmp5 = tl.where(xmask, tmp3, 0)
|
| 88 |
+
tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
|
| 89 |
+
tmp7 = tmp6.to(tl.float32)
|
| 90 |
+
tmp9 = 0.6931471805599453
|
| 91 |
+
tmp10 = tmp8 * tmp9
|
| 92 |
+
tmp11 = 1.4426950408889634
|
| 93 |
+
tmp12 = tmp10 * tmp11
|
| 94 |
+
tmp13 = tmp7 - tmp12
|
| 95 |
+
tl.store(out_ptr1 + (x3), tmp13, xmask)
|
| 96 |
+
''', device_str='cuda')
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
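# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Readable restatement of triton_per_fused_mul_0 above: per (batch, head, query)
# row it sums OUT * dOUT over the head dimension and subtracts the incoming
# logsumexp gradient (the two constants ln(2) and log2(e) multiply out to 1.0);
# the strided indexing of the real kernel is omitted here.
import torch

def delta_ref(out_bf16, d_out_bf16, d_lse_f32):
    s = (out_bf16 * d_out_bf16).sum(dim=-1).float()
    return s - d_lse_f32 * 0.6931471805599453 * 1.4426950408889634
# --------------------------------------------------------------------------------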
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py
|
| 100 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 101 |
+
# Source node to ATen node mapping:
|
| 102 |
+
# Graph fragment:
|
| 103 |
+
# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2]
|
| 104 |
+
# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4]
|
| 105 |
+
# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6]
|
| 106 |
+
# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1]
|
| 107 |
+
# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
|
| 108 |
+
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
|
| 109 |
+
# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3]
|
| 110 |
+
# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5]
|
| 111 |
+
# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_10]
|
| 112 |
+
# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_7]
|
| 113 |
+
# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_13]
|
| 114 |
+
# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_14]
|
| 115 |
+
# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_11]
|
| 116 |
+
# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_12]
|
| 117 |
+
# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_15]
|
| 118 |
+
# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_16]
|
| 119 |
+
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 120 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 121 |
+
# return %getitem_4
|
| 122 |
+
triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
|
| 123 |
+
import triton
|
| 124 |
+
import triton.language as tl
|
| 125 |
+
|
| 126 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 127 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 128 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 129 |
+
|
| 130 |
+
@triton_heuristics.template(
|
| 131 |
+
|
| 132 |
+
num_stages=3,
|
| 133 |
+
num_warps=8,
|
| 134 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 135 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 136 |
+
|
| 137 |
+
)
|
| 138 |
+
@triton.jit
|
| 139 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
|
| 140 |
+
PRESCALE_QK : tl.constexpr = False
|
| 141 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 142 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 143 |
+
WRITE_DQ : tl.constexpr = True
|
| 144 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 145 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 146 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 147 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 148 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 149 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 150 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 151 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 152 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 153 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 154 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 155 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 156 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 157 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 158 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 159 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 160 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 161 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 162 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 163 |
+
Q = arg_Q
|
| 164 |
+
K = arg_K
|
| 165 |
+
V = arg_V
|
| 166 |
+
LSE = arg_LSE
|
| 167 |
+
DELTA = arg_DELTA
|
| 168 |
+
DO = arg_DO
|
| 169 |
+
DQ = arg_DQ
|
| 170 |
+
DV = arg_DV
|
| 171 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 172 |
+
KV_IDX = arg_KV_IDX
|
| 173 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 174 |
+
Q_IDX = arg_Q_IDX
|
| 175 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 176 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 177 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 178 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 179 |
+
|
| 180 |
+
# Sub notation for this kernel:
|
| 181 |
+
#
|
| 182 |
+
# Q: Query, K: Key, V: Value
|
| 183 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 184 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 185 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 186 |
+
# DK: Derivative of Key, is written to via the store_output call due to some limitations with
|
| 187 |
+
# inductor codegen
|
| 188 |
+
# M: Number of queries, N: Number of keys/values
|
| 189 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 190 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 191 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 192 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 193 |
+
# (Modifiable) Performance tuning options
|
| 194 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 195 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 196 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 197 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 198 |
+
#
|
| 199 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 200 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 201 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 202 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 203 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 204 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 205 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 206 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 207 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 208 |
+
|
| 209 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 210 |
+
# or involve a numerics vs. perf tradeoff
|
| 211 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 212 |
+
# about 20% more numerical error, but slightly faster.
|
| 213 |
+
|
| 214 |
+
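# --- Illustrative sketch (not part of the generated kernel) ---------------------
# Toy example of the block-sparse metadata described above: for each row of the
# block grid, KV_IDX lists the kv blocks that may be touched and KV_NUM_BLKS says
# how many of those entries are valid (here a causal pattern at block granularity;
# the real tensors come from the block mask passed into flex_attention_backward).
import torch

num_q_blocks = num_kv_blocks = 4
kv_idx = torch.zeros(num_q_blocks, num_kv_blocks, dtype=torch.int32)
kv_num_blks = torch.zeros(num_q_blocks, dtype=torch.int32)
for qb in range(num_q_blocks):
    allowed = torch.arange(qb + 1, dtype=torch.int32)    # block-causal
    kv_idx[qb, : len(allowed)] = allowed
    kv_num_blks[qb] = len(allowed)
# --------------------------------------------------------------------------------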
# Define strides of inputs
|
| 215 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
|
| 216 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
|
| 217 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
|
| 218 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
|
| 219 |
+
|
| 220 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
|
| 221 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
|
| 222 |
+
|
| 223 |
+
ZQ = 1
|
| 224 |
+
HQ = 32
|
| 225 |
+
HKV = 8
|
| 226 |
+
Q_LEN = ks0
|
| 227 |
+
ZKV = 1
|
| 228 |
+
KV_LEN = ks1
|
| 229 |
+
|
| 230 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 231 |
+
|
| 232 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 233 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 234 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 235 |
+
|
| 236 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 237 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 238 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 239 |
+
|
| 240 |
+
SPARSE_Z = 1
|
| 241 |
+
SPARSE_HQ = 1
|
| 242 |
+
|
| 243 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 244 |
+
|
| 245 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 246 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 247 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 248 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 249 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 250 |
+
|
| 251 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 252 |
+
K += k_adj
|
| 253 |
+
V += v_adj
|
| 254 |
+
DV += dv_adj
|
| 255 |
+
|
| 256 |
+
RCP_LN2 = 1.44269504
|
| 257 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 258 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 259 |
+
|
| 260 |
+
if pid >= NUM_KV_BLOCKS:
|
| 261 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 262 |
+
# THIS BLOCK DOES DQ
|
| 263 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 264 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 265 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 266 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 267 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 268 |
+
stride_kv_num_blks_h = 1
|
| 269 |
+
stride_kv_idx_h = 1
|
| 270 |
+
stride_kv_idx_m = 1
|
| 271 |
+
|
| 272 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 273 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 274 |
+
|
| 275 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 276 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 277 |
+
|
| 278 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
|
| 279 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 280 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 281 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 282 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 283 |
+
|
| 284 |
+
Q2 = Q + q_adj2
|
| 285 |
+
DO2 = DO + do_adj2
|
| 286 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 287 |
+
# if Q is broadcasted)
|
| 288 |
+
DQ2 = DQ + dq_adj2
|
| 289 |
+
LSE2 = LSE + off_chz2
|
| 290 |
+
DELTA2 = DELTA + off_chz2
|
| 291 |
+
|
| 292 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 293 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 294 |
+
|
| 295 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 296 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 297 |
+
|
| 298 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 299 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 300 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 301 |
+
|
| 302 |
+
if PRESCALE_QK:
|
| 303 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 304 |
+
|
| 305 |
+
if IS_DIVISIBLE:
|
| 306 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 307 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 308 |
+
else:
|
| 309 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 310 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 311 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 312 |
+
lse = lse[:, None]
|
| 313 |
+
|
| 314 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 315 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 316 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 317 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 318 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 319 |
+
|
| 320 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 321 |
+
dq = bwd_dq_inner(
|
| 322 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 323 |
+
K, V,
|
| 324 |
+
dq, q, do, Di, lse,
|
| 325 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 326 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 327 |
+
kv_indices, sparse_kv_num_blocks,
|
| 328 |
+
MATMUL_PRECISION,
|
| 329 |
+
IS_FULL_BLOCKS=False,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if HAS_FULL_BLOCKS:
|
| 333 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 334 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 335 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 336 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 337 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 338 |
+
|
| 339 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 340 |
+
dq = bwd_dq_inner(
|
| 341 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 342 |
+
K, V,
|
| 343 |
+
dq, q, do, Di, lse,
|
| 344 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 345 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 346 |
+
kv_indices, sparse_kv_num_blocks,
|
| 347 |
+
MATMUL_PRECISION,
|
| 348 |
+
IS_FULL_BLOCKS=True,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# Write back dQ.
|
| 352 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 353 |
+
dq *= SM_SCALE
|
| 354 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 355 |
+
tl.store(dq_ptrs, dq)
|
| 356 |
+
else:
|
| 357 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 358 |
+
else:
|
| 359 |
+
# THIS BLOCK DOES DK & DV
|
| 360 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 361 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 362 |
+
|
| 363 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 364 |
+
|
| 365 |
+
stride_q_num_blks_h = 1
|
| 366 |
+
stride_q_idx_h = 1
|
| 367 |
+
stride_q_idx_n = 1
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 371 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 372 |
+
|
| 373 |
+
start_n1 = pid * BLOCK_N1
|
| 374 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 375 |
+
|
| 376 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 377 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 378 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 379 |
+
|
| 380 |
+
if PRESCALE_QK:
|
| 381 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 382 |
+
|
| 383 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 384 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 385 |
+
|
| 386 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
|
| 387 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 388 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 389 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 390 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 391 |
+
|
| 392 |
+
Q1 = Q + q_adj1
|
| 393 |
+
DO1 = DO + do_adj1
|
| 394 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 395 |
+
# if Q is broadcasted)
|
| 396 |
+
LSE1 = LSE + off_chz1
|
| 397 |
+
DELTA1 = DELTA + off_chz1
|
| 398 |
+
|
| 399 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 400 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 401 |
+
|
| 402 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 403 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 404 |
+
|
| 405 |
+
# ~~~~~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 406 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 407 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 408 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 409 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 410 |
+
|
| 411 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 412 |
+
dk, dv = bwd_dkdv_inner(
|
| 413 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 414 |
+
Q1, DO1, DELTA1, LSE1,
|
| 415 |
+
dk, dv, k, v,
|
| 416 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 417 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 418 |
+
q_indices, sparse_q_num_blocks,
|
| 419 |
+
MATMUL_PRECISION,
|
| 420 |
+
IS_FULL_BLOCKS=False,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if HAS_FULL_BLOCKS:
|
| 425 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 426 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 427 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 428 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 429 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 430 |
+
|
| 431 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 432 |
+
dk, dv = bwd_dkdv_inner(
|
| 433 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 434 |
+
Q1, DO1, DELTA1, LSE1,
|
| 435 |
+
dk, dv, k, v,
|
| 436 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 437 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 438 |
+
q_indices, sparse_q_num_blocks,
|
| 439 |
+
MATMUL_PRECISION,
|
| 440 |
+
IS_FULL_BLOCKS=True,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Write back dV and dK.
|
| 444 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 445 |
+
|
| 446 |
+
index_n = offs_n1[:, None]
|
| 447 |
+
index_k = offs_k[None, :]
|
| 448 |
+
index_v = offs_v[None, :]
|
| 449 |
+
|
| 450 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 451 |
+
tl.store(dv_ptrs, dv)
|
| 452 |
+
else:
|
| 453 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 454 |
+
|
| 455 |
+
dk *= SM_SCALE
|
| 456 |
+
|
| 457 |
+
if SAFE_HEAD_DIM:
|
| 458 |
+
mask = index_n < KV_LEN
|
| 459 |
+
else:
|
| 460 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 461 |
+
|
| 462 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
|
| 463 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
|
| 464 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 465 |
+
xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
|
| 466 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
|
| 467 |
+
|
| 468 |
+
@triton.jit
|
| 469 |
+
def bwd_dq_inner(
|
| 470 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 471 |
+
K, V, # pointers
|
| 472 |
+
dq, q, do, Di, lse,
|
| 473 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 474 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 475 |
+
kv_indices, sparse_kv_num_blocks,
|
| 476 |
+
MATMUL_PRECISION,
|
| 477 |
+
IS_FULL_BLOCKS,
|
| 478 |
+
):
|
| 479 |
+
PRESCALE_QK : tl.constexpr = False
|
| 480 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 481 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 482 |
+
WRITE_DQ : tl.constexpr = True
|
| 483 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 484 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 485 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 486 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 487 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 488 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 489 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 490 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 491 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 492 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 493 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 494 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 495 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 496 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 497 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 498 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 499 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 500 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 501 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 502 |
+
|
| 503 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 504 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 505 |
+
Q_LEN = ks0
|
| 506 |
+
KV_LEN = ks1
|
| 507 |
+
|
| 508 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 509 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 510 |
+
|
| 511 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 512 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 513 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 514 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 515 |
+
|
| 516 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 517 |
+
|
| 518 |
+
for start_n in range(0, hi):
|
| 519 |
+
dq = bwd_dq_block_mn(
|
| 520 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 521 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 522 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 523 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 524 |
+
kv_indices, sparse_kv_num_blocks,
|
| 525 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 526 |
+
IS_FULL_BLOCKS,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
# Increment pointers.
|
| 530 |
+
offset = get_offset_for_next_block(
|
| 531 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 532 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
kT_ptrs += offset * stride_kn
|
| 536 |
+
vT_ptrs += offset * stride_vn
|
| 537 |
+
|
| 538 |
+
offs_n2 += offset
|
| 539 |
+
|
| 540 |
+
return dq
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
@triton.jit
|
| 544 |
+
def bwd_dq_block_mn(
|
| 545 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 546 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 547 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 548 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 549 |
+
kv_indices, sparse_kv_num_blocks,
|
| 550 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 551 |
+
IS_FULL_BLOCKS,
|
| 552 |
+
):
|
| 553 |
+
PRESCALE_QK : tl.constexpr = False
|
| 554 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 555 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 556 |
+
WRITE_DQ : tl.constexpr = True
|
| 557 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 558 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 559 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 560 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 561 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 562 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 563 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 564 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 567 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 568 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 569 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 570 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 571 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 572 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 573 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 574 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 575 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# NB reversed order since K is transposed
|
| 579 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 580 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 581 |
+
if not PRESCALE_QK:
|
| 582 |
+
qk *= SM_SCALE
|
| 583 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 584 |
+
pre_mod_scores = qk
|
| 585 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 586 |
+
# The boundary check is done for the outer loop, but since we're iterating across the N dim here,
|
| 587 |
+
# the M reads can go out of bounds for the PIDs spanning the Q_LEN boundary
|
| 588 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 589 |
+
|
| 590 |
+
tmp0 = (qk)
|
| 591 |
+
post_mod_scores = tmp0
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
if not IS_DIVISIBLE:
|
| 597 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 598 |
+
|
| 599 |
+
if not IS_FULL_BLOCKS:
|
| 600 |
+
tmp1 = (m)
|
| 601 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 602 |
+
tmp3 = tmp1 < tmp2
|
| 603 |
+
tmp4 = (n)
|
| 604 |
+
tmp5 = tmp4 <= tmp1
|
| 605 |
+
tmp6 = tmp3 & tmp5
|
| 606 |
+
tmp7 = tmp1 >= tmp2
|
| 607 |
+
tmp8 = tmp4 < tmp2
|
| 608 |
+
tmp9 = tmp7 & tmp8
|
| 609 |
+
tmp10 = tmp8 == 0
|
| 610 |
+
tmp11 = tmp7 & tmp10
|
| 611 |
+
tmp12 = tmp1 - tmp2
|
| 612 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 613 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 614 |
+
tmp15 = tmp4 - tmp2
|
| 615 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 616 |
+
tmp17 = tmp14 == tmp16
|
| 617 |
+
tmp18 = tmp11 & tmp17
|
| 618 |
+
tmp19 = tmp9 | tmp18
|
| 619 |
+
tmp20 = tmp6 | tmp19
|
| 620 |
+
mask_mod_output = tmp20
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
# apply mask for partial masked block
|
| 624 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 625 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 626 |
+
if not PRESCALE_QK:
|
| 627 |
+
post_mod_scores *= RCP_LN2
|
| 628 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 629 |
+
# Compute dP and dS.
|
| 630 |
+
# NB reversed order since V is transposed
|
| 631 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 632 |
+
|
| 633 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 634 |
+
ds = p * (dp - Di[:, None])
|
| 635 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 636 |
+
tmp21 = (ds)
|
| 637 |
+
grad_scores = tmp21
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
if not IS_DIVISIBLE:
|
| 641 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 642 |
+
|
| 643 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 644 |
+
if WRITE_DQ:
|
| 645 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 646 |
+
|
| 647 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 648 |
+
ds = grad_scores
|
| 649 |
+
|
| 650 |
+
if not IS_FULL_BLOCKS:
|
| 651 |
+
# (grads) apply mask for partially unmasked block
|
| 652 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 653 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 654 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 655 |
+
# Compute dQ.
|
| 656 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 657 |
+
|
| 658 |
+
return dq
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
@triton.jit
|
| 662 |
+
def bwd_dkdv_inner(
|
| 663 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 664 |
+
Q, DO, DELTA, LSE, # pointers
|
| 665 |
+
dk, dv, k, v,
|
| 666 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 667 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 668 |
+
q_indices, sparse_q_num_blocks,
|
| 669 |
+
MATMUL_PRECISION,
|
| 670 |
+
IS_FULL_BLOCKS,
|
| 671 |
+
):
|
| 672 |
+
PRESCALE_QK : tl.constexpr = False
|
| 673 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 674 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 675 |
+
WRITE_DQ : tl.constexpr = True
|
| 676 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 677 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 678 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 679 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 680 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 681 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 682 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 683 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 684 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 685 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 686 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 687 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 688 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 689 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 690 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 691 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 692 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 693 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 694 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 695 |
+
|
| 696 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 697 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 698 |
+
Q_LEN = ks0
|
| 699 |
+
KV_LEN = ks1
|
| 700 |
+
|
| 701 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 702 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 703 |
+
|
| 704 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 705 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 706 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 707 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 708 |
+
|
| 709 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 710 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 711 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 712 |
+
|
| 713 |
+
for start_m in range(0, hi):
|
| 714 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 715 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 716 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 717 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 718 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 719 |
+
q_indices, sparse_q_num_blocks,
|
| 720 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 721 |
+
IS_FULL_BLOCKS,
|
| 722 |
+
)
|
| 723 |
+
# Increment pointers.
|
| 724 |
+
offset = get_offset_for_next_block(
|
| 725 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 726 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
qT_ptrs += offset * stride_qm
|
| 730 |
+
do_ptrs += offset * stride_dom
|
| 731 |
+
offs_m1 += offset
|
| 732 |
+
|
| 733 |
+
return dk, dv
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
@triton.jit
|
| 737 |
+
def bwd_dkdv_block_mn(
|
| 738 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 739 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 740 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 741 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 742 |
+
q_indices, sparse_q_num_blocks,
|
| 743 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 744 |
+
IS_FULL_BLOCKS,
|
| 745 |
+
):
|
| 746 |
+
PRESCALE_QK : tl.constexpr = False
|
| 747 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 748 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 749 |
+
WRITE_DQ : tl.constexpr = True
|
| 750 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 751 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 752 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 753 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 754 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 755 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 756 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 757 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 758 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 759 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 760 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 761 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 762 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 763 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 764 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 765 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 766 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 767 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 768 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# NB reversed order since Q is transposed
|
| 772 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 773 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 774 |
+
if IS_DIVISIBLE:
|
| 775 |
+
lse = tl.load(LSE + offs_m1)
|
| 776 |
+
else:
|
| 777 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 778 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
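# Rows that were fully masked in the forward pass store lse = -inf; substituting 0.0
# here keeps the exp2(post_mod_scores - lse) below finite (their masked scores are -inf,
# so those entries still evaluate to exp2(-inf) = 0 instead of producing inf - inf NaNs).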
|
| 779 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 780 |
+
if not PRESCALE_QK:
|
| 781 |
+
qkT *= SM_SCALE
|
| 782 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 783 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 784 |
+
# The boundary check is done for the outer loop, but since we're iterating across the M dim here,
|
| 785 |
+
# the n reads can go out of bounds for the PIDs spanning the KV_LEN boundary
|
| 786 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 787 |
+
|
| 788 |
+
pre_mod_scores = qkT
|
| 789 |
+
tmp22 = (qkT)
|
| 790 |
+
post_mod_scores = tmp22
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
if not IS_DIVISIBLE:
|
| 795 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 796 |
+
|
| 797 |
+
if not IS_FULL_BLOCKS:
|
| 798 |
+
tmp23 = (m)
|
| 799 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 800 |
+
tmp25 = tmp23 < tmp24
|
| 801 |
+
tmp26 = (n)
|
| 802 |
+
tmp27 = tmp26 <= tmp23
|
| 803 |
+
tmp28 = tmp25 & tmp27
|
| 804 |
+
tmp29 = tmp23 >= tmp24
|
| 805 |
+
tmp30 = tmp26 < tmp24
|
| 806 |
+
tmp31 = tmp29 & tmp30
|
| 807 |
+
tmp32 = tmp30 == 0
|
| 808 |
+
tmp33 = tmp29 & tmp32
|
| 809 |
+
tmp34 = tmp23 - tmp24
|
| 810 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 811 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 812 |
+
tmp37 = tmp26 - tmp24
|
| 813 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 814 |
+
tmp39 = tmp36 == tmp38
|
| 815 |
+
tmp40 = tmp33 & tmp39
|
| 816 |
+
tmp41 = tmp31 | tmp40
|
| 817 |
+
tmp42 = tmp28 | tmp41
|
| 818 |
+
mask_mod_output = tmp42
|
| 819 |
+
|
| 820 |
+
# (grads) apply mask for fully masked block
|
| 821 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 822 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 823 |
+
if not PRESCALE_QK:
|
| 824 |
+
post_mod_scores *= RCP_LN2
|
| 825 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 826 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 827 |
+
# Compute dV.
|
| 828 |
+
ppT = pT
|
| 829 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 830 |
+
if IS_DIVISIBLE:
|
| 831 |
+
Di = tl.load(DELTA + offs_m1)
|
| 832 |
+
else:
|
| 833 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 834 |
+
# Compute dP and dS.
|
| 835 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 836 |
+
dsT = pT * (dpT - Di[None, :])
|
| 837 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 838 |
+
tmp43 = (dsT)
|
| 839 |
+
grad_scores = tmp43
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
if not IS_DIVISIBLE:
|
| 844 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 845 |
+
|
| 846 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 847 |
+
if not WRITE_DQ:
|
| 848 |
+
idx_b = off_z
|
| 849 |
+
idx_h = off_hq
|
| 850 |
+
idx_m = m
|
| 851 |
+
idx_n = n
|
| 852 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 853 |
+
|
| 854 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 855 |
+
dsT = grad_scores
|
| 856 |
+
if not IS_FULL_BLOCKS:
|
| 857 |
+
# (grads) apply mask for partially unmasked block
|
| 858 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 859 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 860 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 861 |
+
|
| 862 |
+
return dk, dv
|
| 863 |
+
|
| 864 |
+
# Utility triton funcs
|
| 865 |
+
@triton.jit
|
| 866 |
+
def get_offset_for_next_block(
|
| 867 |
+
loop_iter, col_indices, total_blocks,
|
| 868 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 869 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 870 |
+
):
|
| 871 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 872 |
+
return BLOCK
|
| 873 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 874 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 875 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 876 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 877 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 878 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
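# needs_jump is 0/1: while the next iteration stays inside the current sparse block
# ((loop_iter + 1) % SPARSE_BLOCK_MULTIPLE != 0) we advance by one BLOCK; when it
# crosses a sparse-block boundary we jump by the gap between the two sparse block
# starts, (next_block - cur_block) * SPARSE_BLOCK, minus the (SPARSE_BLOCK_MULTIPLE - 1)
# BLOCK-sized steps already taken inside the current sparse block.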
|
| 879 |
+
return offset
|
| 880 |
+
|
| 881 |
+
@triton.jit
|
| 882 |
+
def get_bounded_indices(indices, max_len=None):
|
| 883 |
+
return indices % max_len if max_len is not None else indices
|
| 884 |
+
|
| 885 |
+
@triton.jit
|
| 886 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 887 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 888 |
+
return tl.load(block_ptr)
|
| 889 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 890 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 891 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 892 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 893 |
+
else:
|
| 894 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 895 |
+
|
| 896 |
+
@triton.jit
|
| 897 |
+
def load_checked_2d(
|
| 898 |
+
ptr,
|
| 899 |
+
offs_m,
|
| 900 |
+
offs_n,
|
| 901 |
+
stride_m,
|
| 902 |
+
stride_n,
|
| 903 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 904 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 905 |
+
M_LEN: tl.constexpr,
|
| 906 |
+
N_LEN: tl.constexpr,
|
| 907 |
+
):
|
| 908 |
+
# Calculate final pointer if strides are provided
|
| 909 |
+
if stride_m is not None and stride_n is not None:
|
| 910 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 911 |
+
|
| 912 |
+
# Handle all masking cases
|
| 913 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 914 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 915 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 916 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 917 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 918 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 919 |
+
else: # Both divisible
|
| 920 |
+
return tl.load(ptr)
|
| 921 |
+
''', device_str='cuda')
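The constant RCP_LN2 = 1.44269504 used throughout the template above is 1/ln 2 = log2 e: the kernels rescale the scores by it and recompute the attention weights from the saved logsumexp with tl.math.exp2, which maps onto the GPU's native 2^x path. A sketch of the identity being used (assuming, as the subtraction inside exp2 implies, that the lse consumed here is already in the same base-2 units):

\[
p_{mn} \;=\; e^{\,s_{mn}-\mathrm{lse}_m}
       \;=\; 2^{\,(s_{mn}-\mathrm{lse}_m)\log_2 e}
       \;=\; \texttt{exp2}\bigl(s_{mn}\cdot \mathrm{RCP\_LN2} \;-\; \mathrm{lse}_m\cdot \mathrm{RCP\_LN2}\bigr).
\]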
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
async_compile.wait(globals())
|
| 925 |
+
del async_compile
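The dK/dV branch of the backward template above loops off_g over GQA_SHARED_HEADS query heads and accumulates their contributions into a single kv head via off_hq = off_hkv * GQA_SHARED_HEADS + off_g. A minimal host-side sketch of that mapping, using the HQ = 32 / HKV = 8 configuration baked into this kernel (illustrative only, not part of the generated module):

HQ, HKV = 32, 8                    # query heads / kv heads from the kernel constants
GQA_SHARED_HEADS = HQ // HKV       # = 4, matching the constexpr above

def query_heads_for_kv_head(off_hkv: int) -> list[int]:
    # Query-head indices whose dK/dV contributions accumulate into kv head off_hkv.
    return [off_hkv * GQA_SHARED_HEADS + off_g for off_g in range(GQA_SHARED_HEADS)]

assert query_heads_for_kv_head(0) == [0, 1, 2, 3]
assert query_heads_for_kv_head(7) == [28, 29, 30, 31]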
|
| 926 |
+
|
| 927 |
+
class Runner:
|
| 928 |
+
def __init__(self, partitions):
|
| 929 |
+
self.partitions = partitions
|
| 930 |
+
|
| 931 |
+
def recursively_apply_fns(self, fns):
|
| 932 |
+
new_callables = []
|
| 933 |
+
for fn, c in zip(fns, self.partitions):
|
| 934 |
+
new_callables.append(fn(c))
|
| 935 |
+
self.partitions = new_callables
|
| 936 |
+
|
| 937 |
+
def call(self, args):
|
| 938 |
+
primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args
|
| 939 |
+
args.clear()
|
| 940 |
+
s37 = primals_8
|
| 941 |
+
s0 = primals_9
|
| 942 |
+
assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
|
| 943 |
+
assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
|
| 944 |
+
assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
|
| 945 |
+
assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 946 |
+
assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
|
| 947 |
+
assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
|
| 948 |
+
assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 949 |
+
assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
|
| 950 |
+
assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 951 |
+
assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
|
| 952 |
+
assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 953 |
+
assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
|
| 954 |
+
assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
|
| 955 |
+
assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
|
| 956 |
+
assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
|
| 957 |
+
with torch.cuda._DeviceGuard(7):
|
| 958 |
+
torch.cuda.set_device(7)
|
| 959 |
+
buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
|
| 960 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 961 |
+
triton_per_fused_mul_0_xnumel = 32*s37
|
| 962 |
+
stream7 = get_raw_stream(7)
|
| 963 |
+
triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream7)
|
| 964 |
+
del getitem
|
| 965 |
+
del tangents_2
|
| 966 |
+
buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
|
| 967 |
+
buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
|
| 968 |
+
buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
|
| 969 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 970 |
+
stream7 = get_raw_stream(7)
|
| 971 |
+
triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream7)
|
| 972 |
+
del buf1
|
| 973 |
+
del getitem_1
|
| 974 |
+
del primals_10
|
| 975 |
+
del primals_11
|
| 976 |
+
del primals_12
|
| 977 |
+
del primals_13
|
| 978 |
+
del primals_14
|
| 979 |
+
del primals_15
|
| 980 |
+
del primals_16
|
| 981 |
+
del primals_2
|
| 982 |
+
del primals_4
|
| 983 |
+
del primals_6
|
| 984 |
+
del primals_7
|
| 985 |
+
del tangents_1
|
| 986 |
+
return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, )
|
| 987 |
+
|
| 988 |
+
runner = Runner(partitions=[])
|
| 989 |
+
call = runner.call
|
| 990 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 994 |
+
from torch._dynamo.testing import rand_strided
|
| 995 |
+
from torch._inductor.utils import print_performance
|
| 996 |
+
primals_8 = 80
|
| 997 |
+
primals_9 = 80
|
| 998 |
+
primals_2 = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 999 |
+
primals_4 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1000 |
+
primals_6 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1001 |
+
primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1002 |
+
primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1003 |
+
primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1004 |
+
primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1005 |
+
primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1006 |
+
primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1007 |
+
primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1008 |
+
primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1009 |
+
getitem = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1010 |
+
getitem_1 = rand_strided((1, 32, 80), (2560, 80, 1), device='cuda:7', dtype=torch.float32)
|
| 1011 |
+
tangents_1 = rand_strided((1, 32, 80, 128), (327680, 10240, 128, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1012 |
+
tangents_2 = rand_strided((1, 32, 80), (2560, 80, 1), device='cuda:7', dtype=torch.float32)
|
| 1013 |
+
fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2])
|
| 1014 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
if __name__ == "__main__":
|
| 1018 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 1019 |
+
compiled_module_main('None', benchmark_compiled_module)
|
progress/SpecForge/cache/compiled_kernels/44/3728d77fd47f8b1056ec8670d5b1bd262db03ae9994292fee6203d32e3d9cd03.best_config
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"}
|
progress/SpecForge/cache/compiled_kernels/44/c44m5klhlzg7nfvzfelnbb3hjh2jwzh2e5yyk3vtcvhyw6rbnjo6.py
ADDED
|
@@ -0,0 +1,54 @@
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.persistent_reduction(
|
| 11 |
+
size_hints={'x': 524288, 'r0_': 32},
|
| 12 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 19 |
+
r0_numel = 32
|
| 20 |
+
R0_BLOCK: tl.constexpr = 32
|
| 21 |
+
rnumel = r0_numel
|
| 22 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 23 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 24 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 25 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 26 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 27 |
+
r0_offset = 0
|
| 28 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 29 |
+
roffset = r0_offset
|
| 30 |
+
rindex = r0_index
|
| 31 |
+
r0_2 = r0_index
|
| 32 |
+
x5 = xindex
|
| 33 |
+
x1 = xindex // 128
|
| 34 |
+
x0 = (xindex % 128)
|
| 35 |
+
x3 = ((xindex // 128) % ks0)
|
| 36 |
+
x4 = xindex // ks1
|
| 37 |
+
tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
|
| 38 |
+
tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
|
| 39 |
+
tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
|
| 40 |
+
tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
|
| 41 |
+
tmp2 = float("-inf")
|
| 42 |
+
tmp3 = tmp1 == tmp2
|
| 43 |
+
tmp5 = tmp4 - tmp1
|
| 44 |
+
tmp6 = 0.0
|
| 45 |
+
tmp7 = tl.where(tmp3, tmp6, tmp5)
|
| 46 |
+
tmp8 = libdevice.exp2(tmp7)
|
| 47 |
+
tmp9 = tmp0 * tmp8
|
| 48 |
+
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
|
| 49 |
+
tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
|
| 50 |
+
tmp14 = 1.0
|
| 51 |
+
tmp15 = tl.where(tmp3, tmp14, tmp13)
|
| 52 |
+
tmp16 = (tmp12 / tmp15)
|
| 53 |
+
tmp17 = tmp16.to(tl.float32)
|
| 54 |
+
tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
|
progress/SpecForge/cache/compiled_kernels/4d/c4d7fh2egdfps7aogbncwlp3ihfwtff243bbobq7vrxj2m2grl64.py
ADDED
|
@@ -0,0 +1,51 @@
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.reduction(
|
| 11 |
+
size_hints={'x': 32768, 'r0_': 128},
|
| 12 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 19 |
+
r0_numel = 128
|
| 20 |
+
rnumel = r0_numel
|
| 21 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 22 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 23 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 24 |
+
xmask = xindex < xnumel
|
| 25 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 26 |
+
rbase = r0_base
|
| 27 |
+
x0 = (xindex % ks0)
|
| 28 |
+
x1 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 29 |
+
_tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 30 |
+
x3 = xindex
|
| 31 |
+
for r0_offset in range(0, r0_numel, R0_BLOCK):
|
| 32 |
+
r0_index = r0_offset + r0_base
|
| 33 |
+
r0_mask = r0_index < r0_numel
|
| 34 |
+
roffset = r0_offset
|
| 35 |
+
rindex = r0_index
|
| 36 |
+
r0_2 = r0_index
|
| 37 |
+
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 38 |
+
tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 39 |
+
tmp2 = tmp0 * tmp1
|
| 40 |
+
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
|
| 41 |
+
tmp5 = _tmp4 + tmp3
|
| 42 |
+
_tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
|
| 43 |
+
tmp4 = tl.sum(_tmp4, 1)[:, None]
|
| 44 |
+
tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
|
| 45 |
+
tmp6 = tmp4.to(tl.float32)
|
| 46 |
+
tmp8 = 0.6931471805599453
|
| 47 |
+
tmp9 = tmp7 * tmp8
|
| 48 |
+
tmp10 = 1.4426950408889634
|
| 49 |
+
tmp11 = tmp9 * tmp10
|
| 50 |
+
tmp12 = tmp6 - tmp11
|
| 51 |
+
tl.store(out_ptr1 + (x3), tmp12, xmask)
|
progress/SpecForge/cache/compiled_kernels/4d/fd68b3c1a3fd19883dc58697393b6044e6217afda9ea11f84bd620545197dd6b.best_config
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 40, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"}
|
progress/SpecForge/cache/compiled_kernels/4h/c4h32peoig2erjdxibxrq3sbpm533ci3z57ntqjhdemzxp2rhysl.py
ADDED
|
@@ -0,0 +1,799 @@
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 36 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 37 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 38 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 39 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 41 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 42 |
+
Q = arg_Q
|
| 43 |
+
K = arg_K
|
| 44 |
+
V = arg_V
|
| 45 |
+
LSE = arg_LSE
|
| 46 |
+
DELTA = arg_DELTA
|
| 47 |
+
DO = arg_DO
|
| 48 |
+
DQ = arg_DQ
|
| 49 |
+
DV = arg_DV
|
| 50 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 51 |
+
KV_IDX = arg_KV_IDX
|
| 52 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 53 |
+
Q_IDX = arg_Q_IDX
|
| 54 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 55 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 56 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 57 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 58 |
+
|
| 59 |
+
# Sub notation for this kernel:
|
| 60 |
+
#
|
| 61 |
+
# Q: Query, K: Key, V: Value
|
| 62 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 63 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 64 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 65 |
+
# DK: Derivative of Key, is written to via the store_output call due to some limitations with
|
| 66 |
+
# inductor codegen
|
| 67 |
+
# M: Number of queries, N: Number of keys/values
|
| 68 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 69 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 70 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 71 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 72 |
+
# (Modifiable) Performance tuning options
|
| 73 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 74 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 75 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 76 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 77 |
+
#
|
| 78 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 79 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 80 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 81 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 82 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 83 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 84 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 85 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 86 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 87 |
+
|
| 88 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 89 |
+
# or involve a numerics vs. perf tradeoff
|
| 90 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 91 |
+
# about 20% more numerical error, but slightly faster.
|
| 92 |
+
|
| 93 |
+
# Define strides of inputs
|
| 94 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
|
| 95 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
|
| 96 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
|
| 97 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
|
| 98 |
+
|
| 99 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
|
| 100 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
|
| 101 |
+
|
| 102 |
+
ZQ = 1
|
| 103 |
+
HQ = 32
|
| 104 |
+
HKV = 8
|
| 105 |
+
Q_LEN = ks0
|
| 106 |
+
ZKV = 1
|
| 107 |
+
KV_LEN = ks1
|
| 108 |
+
|
| 109 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 110 |
+
|
| 111 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 112 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 113 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 114 |
+
|
| 115 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 116 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 117 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 118 |
+
|
| 119 |
+
SPARSE_Z = 1
|
| 120 |
+
SPARSE_HQ = 1
|
| 121 |
+
|
| 122 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 123 |
+
|
| 124 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 125 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 126 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 127 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 128 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 129 |
+
|
| 130 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 131 |
+
K += k_adj
|
| 132 |
+
V += v_adj
|
| 133 |
+
DV += dv_adj
|
| 134 |
+
|
| 135 |
+
RCP_LN2 = 1.44269504
|
| 136 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 137 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 138 |
+
|
| 139 |
+
if pid >= NUM_KV_BLOCKS:
|
| 140 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 141 |
+
# THIS BLOCK DOES DQ
|
| 142 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 143 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 144 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 145 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 146 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 147 |
+
stride_kv_num_blks_h = 1
|
| 148 |
+
stride_kv_idx_h = 1
|
| 149 |
+
stride_kv_idx_m = 1
|
| 150 |
+
|
| 151 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 152 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 153 |
+
|
| 154 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 155 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 156 |
+
|
| 157 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
|
| 158 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 159 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 160 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 161 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 162 |
+
|
| 163 |
+
Q2 = Q + q_adj2
|
| 164 |
+
DO2 = DO + do_adj2
|
| 165 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 166 |
+
# if Q is broadcasted)
|
| 167 |
+
DQ2 = DQ + dq_adj2
|
| 168 |
+
LSE2 = LSE + off_chz2
|
| 169 |
+
DELTA2 = DELTA + off_chz2
|
| 170 |
+
|
| 171 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 172 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 173 |
+
|
| 174 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 175 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 176 |
+
|
| 177 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 178 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 179 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 180 |
+
|
| 181 |
+
if PRESCALE_QK:
|
| 182 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 183 |
+
|
| 184 |
+
if IS_DIVISIBLE:
|
| 185 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 186 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 187 |
+
else:
|
| 188 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 189 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 190 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 191 |
+
lse = lse[:, None]
|
| 192 |
+
|
| 193 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 194 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 195 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 196 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 197 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 198 |
+
|
| 199 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 200 |
+
dq = bwd_dq_inner(
|
| 201 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 202 |
+
K, V,
|
| 203 |
+
dq, q, do, Di, lse,
|
| 204 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 205 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 206 |
+
kv_indices, sparse_kv_num_blocks,
|
| 207 |
+
MATMUL_PRECISION,
|
| 208 |
+
IS_FULL_BLOCKS=False,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if HAS_FULL_BLOCKS:
|
| 212 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 213 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 214 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 215 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 216 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 217 |
+
|
| 218 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 219 |
+
dq = bwd_dq_inner(
|
| 220 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 221 |
+
K, V,
|
| 222 |
+
dq, q, do, Di, lse,
|
| 223 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 224 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 225 |
+
kv_indices, sparse_kv_num_blocks,
|
| 226 |
+
MATMUL_PRECISION,
|
| 227 |
+
IS_FULL_BLOCKS=True,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Write back dQ.
|
| 231 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 232 |
+
dq *= SM_SCALE
|
| 233 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 234 |
+
tl.store(dq_ptrs, dq)
|
| 235 |
+
else:
|
| 236 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 237 |
+
else:
|
| 238 |
+
# THIS BLOCK DOES DK & DV
|
| 239 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 240 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 241 |
+
|
| 242 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 243 |
+
|
| 244 |
+
stride_q_num_blks_h = 1
|
| 245 |
+
stride_q_idx_h = 1
|
| 246 |
+
stride_q_idx_n = 1
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 250 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 251 |
+
|
| 252 |
+
start_n1 = pid * BLOCK_N1
|
| 253 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 254 |
+
|
| 255 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 256 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 257 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 258 |
+
|
| 259 |
+
if PRESCALE_QK:
|
| 260 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 261 |
+
|
| 262 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 263 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 264 |
+
|
| 265 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
|
| 266 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 267 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 268 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 269 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 270 |
+
|
| 271 |
+
Q1 = Q + q_adj1
|
| 272 |
+
DO1 = DO + do_adj1
|
| 273 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 274 |
+
# if Q is broadcasted)
|
| 275 |
+
LSE1 = LSE + off_chz1
|
| 276 |
+
DELTA1 = DELTA + off_chz1
|
| 277 |
+
|
| 278 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 279 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 280 |
+
|
| 281 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 282 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 283 |
+
|
| 284 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 285 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 286 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 287 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 288 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 289 |
+
|
| 290 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 291 |
+
dk, dv = bwd_dkdv_inner(
|
| 292 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 293 |
+
Q1, DO1, DELTA1, LSE1,
|
| 294 |
+
dk, dv, k, v,
|
| 295 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 296 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 297 |
+
q_indices, sparse_q_num_blocks,
|
| 298 |
+
MATMUL_PRECISION,
|
| 299 |
+
IS_FULL_BLOCKS=False,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if HAS_FULL_BLOCKS:
|
| 304 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 305 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 306 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 307 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 308 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 309 |
+
|
| 310 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 311 |
+
dk, dv = bwd_dkdv_inner(
|
| 312 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 313 |
+
Q1, DO1, DELTA1, LSE1,
|
| 314 |
+
dk, dv, k, v,
|
| 315 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 316 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 317 |
+
q_indices, sparse_q_num_blocks,
|
| 318 |
+
MATMUL_PRECISION,
|
| 319 |
+
IS_FULL_BLOCKS=True,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Write back dV and dK.
|
| 323 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 324 |
+
|
| 325 |
+
index_n = offs_n1[:, None]
|
| 326 |
+
index_k = offs_k[None, :]
|
| 327 |
+
index_v = offs_v[None, :]
|
| 328 |
+
|
| 329 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 330 |
+
tl.store(dv_ptrs, dv)
|
| 331 |
+
else:
|
| 332 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 333 |
+
|
| 334 |
+
dk *= SM_SCALE
|
| 335 |
+
|
| 336 |
+
if SAFE_HEAD_DIM:
|
| 337 |
+
mask = index_n < KV_LEN
|
| 338 |
+
else:
|
| 339 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 340 |
+
|
| 341 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 342 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 343 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 344 |
+
xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
|
| 345 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
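# --- illustrative sketch (editor's addition, not emitted by Inductor) ---------
# This branch accumulates dK/dV for one kv head while looping over the
# GQA_SHARED_HEADS query heads that share it; the dV it writes is the
# "broadcasted" per-query-batch gradient, which (per the comments above) is
# reduced to the kv batch in a later pass. A hypothetical host-side analogue of
# that final reduction, with toy shapes:
import torch

Bq, Bkv, Hkv, KV_LEN, V_HEAD_DIM = 4, 1, 8, 256, 64        # assumed toy sizes
dv_broadcast = torch.randn(Bq, Hkv, KV_LEN, V_HEAD_DIM)     # shape this kernel writes
dv = dv_broadcast.sum(dim=0, keepdim=True)                  # reduce Bq -> Bkv (broadcast case)
assert dv.shape == (Bkv, Hkv, KV_LEN, V_HEAD_DIM)
# ------------------------------------------------------------------------------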
|
| 346 |
+
|
| 347 |
+
@triton.jit
|
| 348 |
+
def bwd_dq_inner(
|
| 349 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 350 |
+
K, V, # pointers
|
| 351 |
+
dq, q, do, Di, lse,
|
| 352 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 353 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 354 |
+
kv_indices, sparse_kv_num_blocks,
|
| 355 |
+
MATMUL_PRECISION,
|
| 356 |
+
IS_FULL_BLOCKS,
|
| 357 |
+
):
|
| 358 |
+
PRESCALE_QK : tl.constexpr = False
|
| 359 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 360 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 361 |
+
WRITE_DQ : tl.constexpr = True
|
| 362 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 363 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 364 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 365 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 366 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 367 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 368 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 369 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 370 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 371 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 372 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 373 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 374 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 375 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 376 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 377 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 378 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 379 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 380 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 381 |
+
|
| 382 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 383 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 384 |
+
Q_LEN = ks0
|
| 385 |
+
KV_LEN = ks1
|
| 386 |
+
|
| 387 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 388 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 389 |
+
|
| 390 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 391 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 392 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 393 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 394 |
+
|
| 395 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 396 |
+
|
| 397 |
+
for start_n in range(0, hi):
|
| 398 |
+
dq = bwd_dq_block_mn(
|
| 399 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 400 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 401 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 402 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 403 |
+
kv_indices, sparse_kv_num_blocks,
|
| 404 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 405 |
+
IS_FULL_BLOCKS,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Increment pointers.
|
| 409 |
+
offset = get_offset_for_next_block(
|
| 410 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 411 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
kT_ptrs += offset * stride_kn
|
| 415 |
+
vT_ptrs += offset * stride_vn
|
| 416 |
+
|
| 417 |
+
offs_n2 += offset
|
| 418 |
+
|
| 419 |
+
return dq
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
@triton.jit
|
| 423 |
+
def bwd_dq_block_mn(
|
| 424 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 425 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 426 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 427 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 428 |
+
kv_indices, sparse_kv_num_blocks,
|
| 429 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 430 |
+
IS_FULL_BLOCKS,
|
| 431 |
+
):
|
| 432 |
+
PRESCALE_QK : tl.constexpr = False
|
| 433 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 434 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 435 |
+
WRITE_DQ : tl.constexpr = True
|
| 436 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 437 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 438 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 439 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 440 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 441 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 442 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 443 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 444 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 445 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 446 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 447 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 448 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 449 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 450 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 451 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 452 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 453 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 454 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# NB reversed order since K is transposed
|
| 458 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 459 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 460 |
+
if not PRESCALE_QK:
|
| 461 |
+
qk *= SM_SCALE
|
| 462 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 463 |
+
pre_mod_scores = qk
|
| 464 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 465 |
+
# The boundary check is done in the outer loop, but since we're iterating across the N dim here,
|
| 466 |
+
# the M reads can go out of bounds for the PIDs spanning the Q_LEN boundary
|
| 467 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 468 |
+
|
| 469 |
+
tmp0 = (qk)
|
| 470 |
+
post_mod_scores = tmp0
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if not IS_DIVISIBLE:
|
| 476 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 477 |
+
|
| 478 |
+
if not IS_FULL_BLOCKS:
|
| 479 |
+
tmp1 = (m)
|
| 480 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 481 |
+
tmp3 = tmp1 < tmp2
|
| 482 |
+
tmp4 = (n)
|
| 483 |
+
tmp5 = tmp4 <= tmp1
|
| 484 |
+
tmp6 = tmp3 & tmp5
|
| 485 |
+
tmp7 = tmp1 >= tmp2
|
| 486 |
+
tmp8 = tmp4 < tmp2
|
| 487 |
+
tmp9 = tmp7 & tmp8
|
| 488 |
+
tmp10 = tmp8 == 0
|
| 489 |
+
tmp11 = tmp7 & tmp10
|
| 490 |
+
tmp12 = tmp1 - tmp2
|
| 491 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 492 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 493 |
+
tmp15 = tmp4 - tmp2
|
| 494 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 495 |
+
tmp17 = tmp14 == tmp16
|
| 496 |
+
tmp18 = tmp11 & tmp17
|
| 497 |
+
tmp19 = tmp9 | tmp18
|
| 498 |
+
tmp20 = tmp6 | tmp19
|
| 499 |
+
mask_mod_output = tmp20
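# Editor's note (illustrative reading of the generated code above): the two
# tl.where((a < 0) != (b < 0), ...) expressions are Inductor's emulation of
# Python-style floor division, here dividing m and n by 16. For in-range indices
# (m >= 0, n >= 0, which is what get_bounded_indices produces) the mask therefore
# appears to reduce to (m // 16) == (n // 16), i.e. a block-diagonal mask_mod with
# block size 16; the earlier clauses only cover negative indices.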
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# apply mask for partial masked block
|
| 503 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 504 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 505 |
+
if not PRESCALE_QK:
|
| 506 |
+
post_mod_scores *= RCP_LN2
|
| 507 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 508 |
+
# Compute dP and dS.
|
| 509 |
+
# NB reversed order since V is transposed
|
| 510 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 511 |
+
|
| 512 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 513 |
+
ds = p * (dp - Di[:, None])
|
| 514 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 515 |
+
tmp21 = (ds)
|
| 516 |
+
grad_scores = tmp21
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
if not IS_DIVISIBLE:
|
| 520 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 521 |
+
|
| 522 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 523 |
+
if WRITE_DQ:
|
| 524 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 525 |
+
|
| 526 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 527 |
+
ds = grad_scores
|
| 528 |
+
|
| 529 |
+
if not IS_FULL_BLOCKS:
|
| 530 |
+
# (grads) apply mask for partially unmasked block
|
| 531 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 532 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 533 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 534 |
+
# Compute dQ.
|
| 535 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 536 |
+
|
| 537 |
+
return dq
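# --- illustrative sketch (editor's addition) ----------------------------------
# bwd_dq_block_mn above computes, per tile: p = 2^(scores*RCP_LN2 - lse),
# dp = do @ v^T, ds = p * (dp - Di), dq += ds @ k, i.e. the standard softmax
# attention backward for dQ written with base-2 exponentials. A dense,
# single-head reference with hypothetical toy shapes (sm_scale matching the
# SM_SCALE constant above):
import torch

def dq_reference(q, k, v, do, sm_scale=0.08838834764831845):
    scores = (q @ k.T) * sm_scale
    p = torch.softmax(scores, dim=-1)
    o = p @ v
    Di = (do * o).sum(dim=-1, keepdim=True)   # plays the role of the DELTA buffer
    dp = do @ v.T
    ds = p * (dp - Di)
    return (ds @ k) * sm_scale                # the kernel applies SM_SCALE at write-back

# q, do: [Q_LEN, D]; k, v: [KV_LEN, D]
dq = dq_reference(torch.randn(16, 8), torch.randn(32, 8), torch.randn(32, 8), torch.randn(16, 8))
# ------------------------------------------------------------------------------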
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
@triton.jit
|
| 541 |
+
def bwd_dkdv_inner(
|
| 542 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 543 |
+
Q, DO, DELTA, LSE, # pointers
|
| 544 |
+
dk, dv, k, v,
|
| 545 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 546 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 547 |
+
q_indices, sparse_q_num_blocks,
|
| 548 |
+
MATMUL_PRECISION,
|
| 549 |
+
IS_FULL_BLOCKS,
|
| 550 |
+
):
|
| 551 |
+
PRESCALE_QK : tl.constexpr = False
|
| 552 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 553 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 554 |
+
WRITE_DQ : tl.constexpr = True
|
| 555 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 556 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 557 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 558 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 561 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 562 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 563 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 567 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 568 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 569 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 570 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 571 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 572 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 573 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 574 |
+
|
| 575 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
Q_LEN = ks0
|
| 578 |
+
KV_LEN = ks1
|
| 579 |
+
|
| 580 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 581 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 582 |
+
|
| 583 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 584 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 585 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 586 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 587 |
+
|
| 588 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 589 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 590 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 591 |
+
|
| 592 |
+
for start_m in range(0, hi):
|
| 593 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 594 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 595 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 596 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 597 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 598 |
+
q_indices, sparse_q_num_blocks,
|
| 599 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
# Increment pointers.
|
| 603 |
+
offset = get_offset_for_next_block(
|
| 604 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 605 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
qT_ptrs += offset * stride_qm
|
| 609 |
+
do_ptrs += offset * stride_dom
|
| 610 |
+
offs_m1 += offset
|
| 611 |
+
|
| 612 |
+
return dk, dv
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
@triton.jit
|
| 616 |
+
def bwd_dkdv_block_mn(
|
| 617 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 618 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 619 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 620 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 621 |
+
q_indices, sparse_q_num_blocks,
|
| 622 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 623 |
+
IS_FULL_BLOCKS,
|
| 624 |
+
):
|
| 625 |
+
PRESCALE_QK : tl.constexpr = False
|
| 626 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 627 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 628 |
+
WRITE_DQ : tl.constexpr = True
|
| 629 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 630 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 631 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 632 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 633 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 634 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 635 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 636 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 637 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 638 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 639 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 640 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 641 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 642 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 643 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 644 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 645 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 646 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 647 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# NB reversed order since Q is transposed
|
| 651 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 652 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 653 |
+
if IS_DIVISIBLE:
|
| 654 |
+
lse = tl.load(LSE + offs_m1)
|
| 655 |
+
else:
|
| 656 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 657 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 658 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 659 |
+
if not PRESCALE_QK:
|
| 660 |
+
qkT *= SM_SCALE
|
| 661 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 662 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 663 |
+
# The boundary check is done in the outer loop, but since we're iterating across the M dim here,
|
| 664 |
+
# the n reads can go out of bounds for the PIDs spanning the KV_LEN boundary
|
| 665 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 666 |
+
|
| 667 |
+
pre_mod_scores = qkT
|
| 668 |
+
tmp22 = (qkT)
|
| 669 |
+
post_mod_scores = tmp22
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
if not IS_DIVISIBLE:
|
| 674 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 675 |
+
|
| 676 |
+
if not IS_FULL_BLOCKS:
|
| 677 |
+
tmp23 = (m)
|
| 678 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 679 |
+
tmp25 = tmp23 < tmp24
|
| 680 |
+
tmp26 = (n)
|
| 681 |
+
tmp27 = tmp26 <= tmp23
|
| 682 |
+
tmp28 = tmp25 & tmp27
|
| 683 |
+
tmp29 = tmp23 >= tmp24
|
| 684 |
+
tmp30 = tmp26 < tmp24
|
| 685 |
+
tmp31 = tmp29 & tmp30
|
| 686 |
+
tmp32 = tmp30 == 0
|
| 687 |
+
tmp33 = tmp29 & tmp32
|
| 688 |
+
tmp34 = tmp23 - tmp24
|
| 689 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 690 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 691 |
+
tmp37 = tmp26 - tmp24
|
| 692 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 693 |
+
tmp39 = tmp36 == tmp38
|
| 694 |
+
tmp40 = tmp33 & tmp39
|
| 695 |
+
tmp41 = tmp31 | tmp40
|
| 696 |
+
tmp42 = tmp28 | tmp41
|
| 697 |
+
mask_mod_output = tmp42
|
| 698 |
+
|
| 699 |
+
# (grads) apply mask for fully masked block
|
| 700 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 701 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 702 |
+
if not PRESCALE_QK:
|
| 703 |
+
post_mod_scores *= RCP_LN2
|
| 704 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 705 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 706 |
+
# Compute dV.
|
| 707 |
+
ppT = pT
|
| 708 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 709 |
+
if IS_DIVISIBLE:
|
| 710 |
+
Di = tl.load(DELTA + offs_m1)
|
| 711 |
+
else:
|
| 712 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 713 |
+
# Compute dP and dS.
|
| 714 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 715 |
+
dsT = pT * (dpT - Di[None, :])
|
| 716 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 717 |
+
tmp43 = (dsT)
|
| 718 |
+
grad_scores = tmp43
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if not IS_DIVISIBLE:
|
| 723 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 724 |
+
|
| 725 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 726 |
+
if not WRITE_DQ:
|
| 727 |
+
idx_b = off_z
|
| 728 |
+
idx_h = off_hq
|
| 729 |
+
idx_m = m
|
| 730 |
+
idx_n = n
|
| 731 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 732 |
+
|
| 733 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 734 |
+
dsT = grad_scores
|
| 735 |
+
if not IS_FULL_BLOCKS:
|
| 736 |
+
# (grads) apply mask for partially unmasked block
|
| 737 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 738 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 739 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 740 |
+
|
| 741 |
+
return dk, dv
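# --- illustrative sketch (editor's addition) ----------------------------------
# bwd_dkdv_block_mn above is the mirror image of the dQ block: per tile it forms
# pT = 2^(qkT*RCP_LN2 - lse), accumulates dV += pT @ do, then dpT = v @ do^T,
# dsT = pT * (dpT - Di), dK += dsT @ q (scaled by SM_SCALE at write-back).
# Dense single-head reference under the same hypothetical toy shapes as above:
import torch

def dkdv_reference(q, k, v, do, sm_scale=0.08838834764831845):
    scores = (q @ k.T) * sm_scale
    p = torch.softmax(scores, dim=-1)            # [Q_LEN, KV_LEN]
    o = p @ v
    Di = (do * o).sum(dim=-1, keepdim=True)      # DELTA
    dv = p.T @ do
    dp = do @ v.T
    ds = p * (dp - Di)
    dk = (ds.T @ q) * sm_scale
    return dk, dv

dk, dv = dkdv_reference(torch.randn(16, 8), torch.randn(32, 8), torch.randn(32, 8), torch.randn(16, 8))
# ------------------------------------------------------------------------------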
|
| 742 |
+
|
| 743 |
+
# Utility triton funcs
|
| 744 |
+
@triton.jit
|
| 745 |
+
def get_offset_for_next_block(
|
| 746 |
+
loop_iter, col_indices, total_blocks,
|
| 747 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 748 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 749 |
+
):
|
| 750 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 751 |
+
return BLOCK
|
| 752 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 753 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 754 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 755 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 756 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 757 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 758 |
+
return offset
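# Editor's note on the helper above (illustrative worked example): with
# SPARSE_BLOCK = 128 and BLOCK = 64 (so SPARSE_BLOCK_MULTIPLE = 2), the loop takes
# two BLOCK-sized steps inside each sparse block. When (loop_iter + 1) is not a
# multiple of 2, offset == BLOCK (advance 64 within the current sparse block); on
# the last step of a sparse block, offset == (next_block - cur_block) * 128 - 64,
# which lands exactly at the start of the next sparse block listed in col_indices.
# E.g. cur_block = 0, next_block = 3 gives a jump of 3*128 - 64 = 320 from offset 64
# to offset 384.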
|
| 759 |
+
|
| 760 |
+
@triton.jit
|
| 761 |
+
def get_bounded_indices(indices, max_len=None):
|
| 762 |
+
return indices % max_len if max_len is not None else indices
|
| 763 |
+
|
| 764 |
+
@triton.jit
|
| 765 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 766 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 767 |
+
return tl.load(block_ptr)
|
| 768 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 769 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 770 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 771 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 772 |
+
else:
|
| 773 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 774 |
+
|
| 775 |
+
@triton.jit
|
| 776 |
+
def load_checked_2d(
|
| 777 |
+
ptr,
|
| 778 |
+
offs_m,
|
| 779 |
+
offs_n,
|
| 780 |
+
stride_m,
|
| 781 |
+
stride_n,
|
| 782 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 783 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 784 |
+
M_LEN: tl.constexpr,
|
| 785 |
+
N_LEN: tl.constexpr,
|
| 786 |
+
):
|
| 787 |
+
# Calculate final pointer if strides are provided
|
| 788 |
+
if stride_m is not None and stride_n is not None:
|
| 789 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 790 |
+
|
| 791 |
+
# Handle all masking cases
|
| 792 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 793 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 794 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 795 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 796 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 797 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 798 |
+
else: # Both divisible
|
| 799 |
+
return tl.load(ptr)
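# --- illustrative sketch (editor's addition) ----------------------------------
# load_checked_2d above is a bounds-checked 2D load: out-of-range rows/columns are
# filled with 0.0, and the masking is skipped entirely when the compile-time
# IS_DIVISIBLE_* flags guarantee the tile cannot run past the tensor. A NumPy
# analogue of the fully masked case (hypothetical helper, not part of the template):
import numpy as np

def load_checked_2d_ref(buf, offs_m, offs_n):
    m = offs_m[:, None]
    n = offs_n[None, :]
    in_bounds = (m < buf.shape[0]) & (n < buf.shape[1])
    out = np.zeros((len(offs_m), len(offs_n)), dtype=buf.dtype)
    mm = np.clip(m, 0, buf.shape[0] - 1)
    nn = np.clip(n, 0, buf.shape[1] - 1)
    out[in_bounds] = buf[mm, nn][in_bounds]      # gather only the valid positions
    return out

tile = load_checked_2d_ref(np.arange(12.0).reshape(3, 4), np.arange(4), np.arange(6))
# ------------------------------------------------------------------------------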
|
progress/SpecForge/cache/compiled_kernels/4k/c4korm4huj2wookuw6gikboxrsp3m5yt45c7fxucyujswm5fgb3u.py
ADDED
|
@@ -0,0 +1,534 @@
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
USE_TMA : tl.constexpr = False
|
| 36 |
+
BLOCK_M : tl.constexpr = 128
|
| 37 |
+
BLOCK_N : tl.constexpr = 64
|
| 38 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 39 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 41 |
+
Q = arg_Q
|
| 42 |
+
K = arg_K
|
| 43 |
+
V = arg_V
|
| 44 |
+
LSE = arg_LSE
|
| 45 |
+
MAX = arg_MAX
|
| 46 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 47 |
+
KV_IDX = arg_KV_IDX
|
| 48 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 49 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 50 |
+
|
| 51 |
+
# Sub notation for this kernel:
|
| 52 |
+
#
|
| 53 |
+
# Q: Query, K: Key, V: Value
|
| 54 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 55 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 56 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 57 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 58 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 59 |
+
#
|
| 60 |
+
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
| 61 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 62 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 63 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 64 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 65 |
+
#
|
| 66 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 67 |
+
#
|
| 68 |
+
# (Modifiable) Performance tuning options
|
| 69 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 70 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 71 |
+
|
| 72 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 73 |
+
# or involve a numerics vs. perf tradeoff
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 75 |
+
# about 20% more numerical error, but slightly faster.
|
| 76 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 77 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 78 |
+
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
| 79 |
+
# contiguous? If so, we don't need to do an indirect jump for every block
|
| 80 |
+
|
| 81 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 82 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
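# Editor's note (illustrative): with the constants baked into this kernel,
# SPARSE_Q_BLOCK_SIZE = 128 and BLOCK_M = 128 give SPARSE_Q_MULTIPLE = 1, while
# SPARSE_KV_BLOCK_SIZE = 128 and BLOCK_N = 64 give SPARSE_KV_MULTIPLE = 2, so each
# block-mask entry covers one query tile and two key/value tiles; the asserts above
# simply guarantee these ratios are whole numbers.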
|
| 83 |
+
|
| 84 |
+
# Define strides of inputs
|
| 85 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 86 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 87 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 88 |
+
|
| 89 |
+
ZQ = 1
|
| 90 |
+
HQ = 32
|
| 91 |
+
Q_LEN = ks0
|
| 92 |
+
ZKV = 1
|
| 93 |
+
KV_LEN = ks1
|
| 94 |
+
|
| 95 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 96 |
+
|
| 97 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 98 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 99 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 100 |
+
|
| 101 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 102 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 103 |
+
off_zkv = off_zq % ZKV
|
| 104 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 105 |
+
off_g = off_hq % GQA_SHARED_HEADS
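# Editor's note (illustrative): with GQA_SHARED_HEADS = 4 and HQ = 32 as defined in
# this kernel, query heads 0-3 read kv head 0 (off_hkv = off_hq // 4), heads 4-7
# read kv head 1, and so on; off_g is the position of the query head within its
# group of four.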
|
| 106 |
+
|
| 107 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 108 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 109 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 110 |
+
|
| 111 |
+
Q = Q + q_offset
|
| 112 |
+
K = K + k_offset
|
| 113 |
+
V = V + v_offset
|
| 114 |
+
|
| 115 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 116 |
+
desc_q = None
|
| 117 |
+
desc_k = None
|
| 118 |
+
desc_v = None
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 124 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 127 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 128 |
+
|
| 129 |
+
stride_kv_num_blks_h = 1
|
| 130 |
+
stride_kv_idx_h = 1
|
| 131 |
+
stride_kv_idx_m = 1
|
| 132 |
+
|
| 133 |
+
# initialize pointer to m and l
|
| 134 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 135 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 136 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 137 |
+
|
| 138 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 139 |
+
|
| 140 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 141 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 142 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 143 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 144 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 145 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 146 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 147 |
+
|
| 148 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 149 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 150 |
+
# both score_mod and mask_mod to it
|
| 151 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 152 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 153 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 154 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# K and V pointers will be passed directly to forward_inner
|
| 158 |
+
|
| 159 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
acc, l_i, m_i = forward_inner(
|
| 163 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 164 |
+
q, K, V,
|
| 165 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 166 |
+
acc, l_i, m_i,
|
| 167 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 168 |
+
kv_start,
|
| 169 |
+
kv_indices, kv_num_blocks,
|
| 170 |
+
0, block_n_end,
|
| 171 |
+
MATMUL_PRECISION,
|
| 172 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 173 |
+
IS_FULL_BLOCKS=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 177 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 178 |
+
# apply mask_mod to them - only score_mod
|
| 179 |
+
if HAS_FULL_BLOCKS:
|
| 180 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 181 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 182 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 183 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 184 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 185 |
+
# K and V pointers will be passed directly to forward_inner
|
| 186 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 190 |
+
q, K, V,
|
| 191 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 194 |
+
kv_start,
|
| 195 |
+
kv_indices, kv_num_blocks,
|
| 196 |
+
0, block_n_end,
|
| 197 |
+
MATMUL_PRECISION,
|
| 198 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 199 |
+
IS_FULL_BLOCKS=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# [Note] Handle fully masked out rows:
|
| 204 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 205 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 206 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 207 |
+
|
| 208 |
+
acc = acc / l_i[:, None]
|
| 209 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 210 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 211 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 212 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 213 |
+
|
| 214 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 215 |
+
|
| 216 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 217 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 218 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 219 |
+
|
| 220 |
+
if OUTPUT_LOGSUMEXP:
|
| 221 |
+
off_hz = off_zq * HQ + off_hq
|
| 222 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 223 |
+
lse = m_i + tl.math.log2(l_i)
|
| 224 |
+
if IS_DIVISIBLE:
|
| 225 |
+
tl.store(l_ptrs, lse)
|
| 226 |
+
else:
|
| 227 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 228 |
+
|
| 229 |
+
if OUTPUT_MAX:
|
| 230 |
+
off_hz = off_zq * HQ + off_hq
|
| 231 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 232 |
+
if IS_DIVISIBLE:
|
| 233 |
+
tl.store(max_ptrs, m_i)
|
| 234 |
+
else:
|
| 235 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Utility triton funcs
|
| 239 |
+
@triton.jit
|
| 240 |
+
def get_offset_for_next_block(
|
| 241 |
+
loop_iter, col_indices, total_blocks,
|
| 242 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 243 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 244 |
+
):
|
| 245 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 246 |
+
return BLOCK
|
| 247 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 248 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 249 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 250 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 251 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 252 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 253 |
+
return offset
|
| 254 |
+
|
| 255 |
+
@triton.jit
|
| 256 |
+
def get_bounded_indices(indices, max_len=None):
|
| 257 |
+
return indices % max_len if max_len is not None else indices
|
| 258 |
+
|
| 259 |
+
@triton.jit
|
| 260 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 261 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 262 |
+
return tl.load(block_ptr)
|
| 263 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 264 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 265 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 266 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 267 |
+
else:
|
| 268 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 269 |
+
|
| 270 |
+
@triton.jit
|
| 271 |
+
def load_checked_2d(
|
| 272 |
+
ptr,
|
| 273 |
+
offs_m,
|
| 274 |
+
offs_n,
|
| 275 |
+
stride_m,
|
| 276 |
+
stride_n,
|
| 277 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 278 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 279 |
+
M_LEN: tl.constexpr,
|
| 280 |
+
N_LEN: tl.constexpr,
|
| 281 |
+
):
|
| 282 |
+
# Calculate final pointer if strides are provided
|
| 283 |
+
if stride_m is not None and stride_n is not None:
|
| 284 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 285 |
+
|
| 286 |
+
# Handle all masking cases
|
| 287 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 288 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 289 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 290 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 291 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 292 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 293 |
+
else: # Both divisible
|
| 294 |
+
return tl.load(ptr)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Common Imports
|
| 298 |
+
@triton.jit
|
| 299 |
+
def forward_block_mn(
|
| 300 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 301 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 302 |
+
# accumulated values
|
| 303 |
+
acc, l_i, m_i,
|
| 304 |
+
# Offsets
|
| 305 |
+
off_z, off_h, offs_m, offs_n,
|
| 306 |
+
# Offsets needed for TMA loads
|
| 307 |
+
kv_start,
|
| 308 |
+
kv_offset,
|
| 309 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 310 |
+
# Strides for K and V
|
| 311 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 312 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 313 |
+
|
| 314 |
+
):
|
| 315 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 316 |
+
PRESCALE_QK : tl.constexpr = False
|
| 317 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 318 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 319 |
+
WRITE_DQ : tl.constexpr = True
|
| 320 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 321 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 322 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 323 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 324 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 325 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 326 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 327 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 328 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 329 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 330 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 331 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 332 |
+
USE_TMA : tl.constexpr = False
|
| 333 |
+
BLOCK_M : tl.constexpr = 128
|
| 334 |
+
BLOCK_N : tl.constexpr = 64
|
| 335 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 336 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 337 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# -- load k --
|
| 341 |
+
# NB reversed order since K is transposed
|
| 342 |
+
kv_base_offset = kv_start + kv_offset
|
| 343 |
+
|
| 344 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 345 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 346 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 347 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 348 |
+
|
| 349 |
+
k = tl.trans(k)
|
| 350 |
+
# -- compute qk ---
|
| 351 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 352 |
+
if not PRESCALE_QK:
|
| 353 |
+
qk *= SM_SCALE
|
| 354 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 355 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 356 |
+
# which is larger than the actual number of elements. To avoid accessing memory out of bounds,
|
| 357 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 358 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 359 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 360 |
+
|
| 361 |
+
tmp0 = (qk)
|
| 362 |
+
post_mod_scores = tmp0
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 366 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 367 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 368 |
+
|
| 369 |
+
if not IS_FULL_BLOCKS:
|
| 370 |
+
tmp1 = (m)
|
| 371 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 372 |
+
tmp3 = tmp1 < tmp2
|
| 373 |
+
tmp4 = (n)
|
| 374 |
+
tmp5 = tmp4 <= tmp1
|
| 375 |
+
tmp6 = tmp3 & tmp5
|
| 376 |
+
tmp7 = tmp1 >= tmp2
|
| 377 |
+
tmp8 = tmp4 < tmp2
|
| 378 |
+
tmp9 = tmp7 & tmp8
|
| 379 |
+
tmp10 = tmp8 == 0
|
| 380 |
+
tmp11 = tmp7 & tmp10
|
| 381 |
+
tmp12 = tmp1 - tmp2
|
| 382 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 383 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 384 |
+
tmp15 = tmp4 - tmp2
|
| 385 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 386 |
+
tmp17 = tmp14 == tmp16
|
| 387 |
+
tmp18 = tmp11 & tmp17
|
| 388 |
+
tmp19 = tmp9 | tmp18
|
| 389 |
+
tmp20 = tmp6 | tmp19
|
| 390 |
+
mask_mod_output = tmp20
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 394 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 395 |
+
# apply mask for partially unmasked blocks
|
| 396 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 397 |
+
|
| 398 |
+
if not PRESCALE_QK:
|
| 399 |
+
post_mod_scores *= RCP_LN2
|
| 400 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 401 |
+
|
| 402 |
+
# -- compute scaling constant ---
|
| 403 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 404 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 405 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 406 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 407 |
+
else:
|
| 408 |
+
m_ij_masked = m_ij
|
| 409 |
+
|
| 410 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 411 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 412 |
+
|
| 413 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 414 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 415 |
+
# m_ij
|
| 416 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 417 |
+
# # -- scale and update acc --
|
| 418 |
+
acc = acc * alpha[:, None]
|
| 419 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 420 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 421 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 422 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 423 |
+
|
| 424 |
+
# -- update m_i
|
| 425 |
+
m_i = m_ij
|
| 426 |
+
|
| 427 |
+
return acc, l_i, m_i
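# --- illustrative sketch (editor's addition) ----------------------------------
# forward_block_mn above performs one step of the streaming ("online") softmax:
# the running max m_i, running denominator l_i and the accumulator are rescaled by
# alpha = 2^(m_i - m_ij) before the new tile is folded in (the kernel works in
# base-2 exponentials, hence the RCP_LN2 change of base). A NumPy rendering of the
# same recurrence in natural-log space, with hypothetical tile shapes:
import numpy as np

def online_softmax_step(scores, v, acc, l_i, m_i):
    m_ij = np.maximum(m_i, scores.max(axis=1))     # new running max per row
    alpha = np.exp(m_i - m_ij)                      # rescale factor for old state
    p = np.exp(scores - m_ij[:, None])              # unnormalised tile probabilities
    l_i = l_i * alpha + p.sum(axis=1)
    acc = acc * alpha[:, None] + p @ v
    return acc, l_i, m_ij

M, N, D = 4, 8, 16
acc, l_i, m_i = np.zeros((M, D)), np.zeros(M), np.full(M, -np.inf)
acc, l_i, m_i = online_softmax_step(np.random.randn(M, N), np.random.randn(N, D), acc, l_i, m_i)
out = acc / l_i[:, None]                             # final normalisation, as in the epilogue
# ------------------------------------------------------------------------------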
|
| 428 |
+
|
| 429 |
+
@triton.jit
|
| 430 |
+
def forward_inner(
|
| 431 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 432 |
+
q, K, V,
|
| 433 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 434 |
+
# accumulated values
|
| 435 |
+
acc, l_i, m_i,
|
| 436 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 437 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 438 |
+
off_z, off_h, offs_m, offs_n,
|
| 439 |
+
# Offsets needed for TMA loads
|
| 440 |
+
kv_start,
|
| 441 |
+
# blocksparse data
|
| 442 |
+
kv_indices, kv_num_blocks,
|
| 443 |
+
# start kv and end kv block
|
| 444 |
+
block_n_start, block_n_end,
|
| 445 |
+
MATMUL_PRECISION,
|
| 446 |
+
# Strides for K and V
|
| 447 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 448 |
+
IS_FULL_BLOCKS,
|
| 449 |
+
):
|
| 450 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 451 |
+
PRESCALE_QK : tl.constexpr = False
|
| 452 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 453 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 454 |
+
WRITE_DQ : tl.constexpr = True
|
| 455 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 456 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 457 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 458 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 459 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 460 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 461 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 462 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 463 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 464 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 465 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 466 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 467 |
+
USE_TMA : tl.constexpr = False
|
| 468 |
+
BLOCK_M : tl.constexpr = 128
|
| 469 |
+
BLOCK_N : tl.constexpr = 64
|
| 470 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 471 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 472 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 476 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 477 |
+
|
| 478 |
+
if PRESCALE_QK:
|
| 479 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 480 |
+
|
| 481 |
+
kv_offset = 0
|
| 482 |
+
|
| 483 |
+
# loop over k, v and update accumulator until block_n_end
|
| 484 |
+
for start_n in range(block_n_start, block_n_end):
|
| 485 |
+
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 486 |
+
if IS_DIVISIBLE:
|
| 487 |
+
acc, l_i, m_i = forward_block_mn(
|
| 488 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 489 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 490 |
+
# accumulated values
|
| 491 |
+
acc, l_i, m_i,
|
| 492 |
+
# Offsets
|
| 493 |
+
off_z, off_h, offs_m, offs_n,
|
| 494 |
+
# Offsets needed for TMA loads
|
| 495 |
+
kv_start,
|
| 496 |
+
kv_offset,
|
| 497 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 498 |
+
# Strides for K and V
|
| 499 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 500 |
+
IS_FULL_BLOCKS,
|
| 501 |
+
)
|
| 502 |
+
else:
|
| 503 |
+
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
|
| 504 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 505 |
+
# However, we choose different strategy for bwd, where we only apply mod & mask
|
| 506 |
+
# to the last block because it's faster a lot.
|
| 507 |
+
acc, l_i, m_i = forward_block_mn(
|
| 508 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 509 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 510 |
+
# accumulated values
|
| 511 |
+
acc, l_i, m_i,
|
| 512 |
+
# Offsets
|
| 513 |
+
off_z, off_h, offs_m, offs_n,
|
| 514 |
+
# Offsets needed for TMA loads
|
| 515 |
+
kv_start,
|
| 516 |
+
kv_offset,
|
| 517 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 518 |
+
# Strides for K and V
|
| 519 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 520 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
offset = get_offset_for_next_block(
|
| 526 |
+
start_n, kv_indices, kv_num_blocks,
|
| 527 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
offs_n = offs_n + offset
|
| 531 |
+
kv_offset += offset
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
return acc, l_i, m_i
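Annotation for readers of this cache dump: the (acc, l_i, m_i) update carried through forward_block_mn above is the standard streaming ("online") softmax. A minimal NumPy sketch of the same per-block update follows; it is illustrative only, and the names scores_block and values_block are placeholders for one [BLOCK_M, BLOCK_N] tile of already base-2-scaled scores and its matching V tile.

import numpy as np

def online_softmax_update(acc, l_i, m_i, scores_block, values_block):
    # scores_block: [BLOCK_M, BLOCK_N], already multiplied by RCP_LN2 (base-2 domain)
    # values_block: [BLOCK_N, V_HEAD_DIM]
    m_ij = np.maximum(m_i, scores_block.max(axis=1))     # new running max per row
    m_ij_safe = np.where(np.isneginf(m_ij), 0.0, m_ij)   # avoid exp2 of (-inf) - (-inf)
    alpha = np.exp2(m_i - m_ij_safe)                      # rescale factor for the old state
    p = np.exp2(scores_block - m_ij_safe[:, None])        # unnormalized probabilities
    l_i = l_i * alpha + p.sum(axis=1)                     # running denominator
    acc = acc * alpha[:, None] + p @ values_block         # running numerator
    return acc, l_i, m_ij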
progress/SpecForge/cache/compiled_kernels/4x/31f9d1ee4882fe2005f02592ea2d9f20a1835b42c5baefd7795e8640f97fdc16.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
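Each .best_config entry records the launch parameters the Inductor autotuner settled on for one kernel. A minimal sketch of how such an entry could be inspected, using only the standard library; the path below is a placeholder, not a file named in this upload:

import json
from pathlib import Path

# Placeholder path for illustration; point it at any *.best_config file in the cache.
cfg_path = Path("progress/SpecForge/cache/compiled_kernels/4x/example.best_config")
cfg = json.loads(cfg_path.read_text())
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])  # autotuned launch parameters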
progress/SpecForge/cache/compiled_kernels/4x/c4xjhgyzut6anhrjeinspoinohfxvyl6skr4gd3vfrscrvsevmya.py
ADDED
@@ -0,0 +1,28 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 32768},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
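Annotation: triton_poi_fused_mul_1 simply scales every element by 0.6931471805599453 = ln(2); applied to a base-2 logsumexp like the one produced by the attention template in this cache, that would convert it back to the natural-log domain (an inference from the constant, not something stated in these files). A minimal PyTorch sketch of the same element-wise operation:

import math
import torch

def fused_mul_1_reference(x: torch.Tensor) -> torch.Tensor:
    # Element-wise equivalent: scale by ln(2). The x0/x1 index arithmetic in the
    # kernel only re-derives a contiguous output layout; it does not change values.
    return x * math.log(2)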
progress/SpecForge/cache/compiled_kernels/5b/c5blvz5sxoj2veuexokuub2zm2pg4l2nqbbny4rr2jhsiiyw6njy.py
ADDED
@@ -0,0 +1,534 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
USE_TMA : tl.constexpr = False
|
| 36 |
+
BLOCK_M : tl.constexpr = 128
|
| 37 |
+
BLOCK_N : tl.constexpr = 64
|
| 38 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 39 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 41 |
+
Q = arg_Q
|
| 42 |
+
K = arg_K
|
| 43 |
+
V = arg_V
|
| 44 |
+
LSE = arg_LSE
|
| 45 |
+
MAX = arg_MAX
|
| 46 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 47 |
+
KV_IDX = arg_KV_IDX
|
| 48 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 49 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 50 |
+
|
| 51 |
+
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
#
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
#
# (Modifiable) Performance tuning options
# BLOCK_M: The thread block size across the seqlen dim of Q.
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
#     about 20% more numerical error, but is slightly faster.
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
#     is not masked out? If so, we can skip an extra safety check.
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
#     contiguous? If so, we don't need to do an indirect jump for every block.
| 81 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 82 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 83 |
+
|
| 84 |
+
# Define strides of inputs
|
| 85 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 86 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 87 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
|
| 88 |
+
|
| 89 |
+
ZQ = 1
|
| 90 |
+
HQ = 32
|
| 91 |
+
Q_LEN = ks0
|
| 92 |
+
ZKV = 1
|
| 93 |
+
KV_LEN = ks1
|
| 94 |
+
|
| 95 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 96 |
+
|
| 97 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 98 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 99 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 100 |
+
|
| 101 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 102 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 103 |
+
off_zkv = off_zq % ZKV
|
| 104 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 105 |
+
off_g = off_hq % GQA_SHARED_HEADS
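With GQA_SHARED_HEADS = 4, as in this specialization, every four query heads share one KV head. A small illustrative check of the mapping above, in plain Python rather than Triton:

GQA_SHARED_HEADS = 4
for off_hq in range(8):
    off_hkv = off_hq // GQA_SHARED_HEADS  # KV head index
    off_g = off_hq % GQA_SHARED_HEADS     # position within the sharing group
# query heads 0-3 map to kv head 0, query heads 4-7 map to kv head 1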
|
| 106 |
+
|
| 107 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 108 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 109 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 110 |
+
|
| 111 |
+
Q = Q + q_offset
|
| 112 |
+
K = K + k_offset
|
| 113 |
+
V = V + v_offset
|
| 114 |
+
|
| 115 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 116 |
+
desc_q = None
|
| 117 |
+
desc_k = None
|
| 118 |
+
desc_v = None
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 124 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 127 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 128 |
+
|
| 129 |
+
stride_kv_num_blks_h = 1
|
| 130 |
+
stride_kv_idx_h = 1
|
| 131 |
+
stride_kv_idx_m = 1
|
| 132 |
+
|
| 133 |
+
# initialize pointer to m and l
|
| 134 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 135 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 136 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 137 |
+
|
| 138 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 139 |
+
|
| 140 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 141 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 142 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 143 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 144 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 145 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 146 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 147 |
+
|
| 148 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 149 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 150 |
+
# both score_mod and mask_mod to it
|
| 151 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 152 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 153 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 154 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# K and V pointers will be passed directly to forward_inner
|
| 158 |
+
|
| 159 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
acc, l_i, m_i = forward_inner(
|
| 163 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 164 |
+
q, K, V,
|
| 165 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 166 |
+
acc, l_i, m_i,
|
| 167 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 168 |
+
kv_start,
|
| 169 |
+
kv_indices, kv_num_blocks,
|
| 170 |
+
0, block_n_end,
|
| 171 |
+
MATMUL_PRECISION,
|
| 172 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 173 |
+
IS_FULL_BLOCKS=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 177 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 178 |
+
# apply mask_mod to them - only score_mod
|
| 179 |
+
if HAS_FULL_BLOCKS:
|
| 180 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 181 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 182 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 183 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 184 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 185 |
+
# K and V pointers will be passed directly to forward_inner
|
| 186 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 190 |
+
q, K, V,
|
| 191 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 194 |
+
kv_start,
|
| 195 |
+
kv_indices, kv_num_blocks,
|
| 196 |
+
0, block_n_end,
|
| 197 |
+
MATMUL_PRECISION,
|
| 198 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 199 |
+
IS_FULL_BLOCKS=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# [Note] Handle fully masked out rows:
|
| 204 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 205 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 206 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 207 |
+
|
| 208 |
+
acc = acc / l_i[:, None]
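To make the note above concrete: for a fully masked row, every p computed in forward_block_mn is exp2(-inf) = 0, so that row of acc is all zeros and l_i is 0; substituting 1 for l_i keeps the division defined and leaves the output row exactly zero. A scalar check, illustrative only:

acc_row, l_row = 0.0, 0.0
l_row = 1.0 if l_row == 0.0 else l_row
out_row = acc_row / l_row  # 0.0, as intended for fully masked rows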
|
| 209 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 210 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 211 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 212 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 213 |
+
|
| 214 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 215 |
+
|
| 216 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 217 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 218 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 219 |
+
|
| 220 |
+
if OUTPUT_LOGSUMEXP:
|
| 221 |
+
off_hz = off_zq * HQ + off_hq
|
| 222 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 223 |
+
lse = m_i + tl.math.log2(l_i)
|
| 224 |
+
if IS_DIVISIBLE:
|
| 225 |
+
tl.store(l_ptrs, lse)
|
| 226 |
+
else:
|
| 227 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 228 |
+
|
| 229 |
+
if OUTPUT_MAX:
|
| 230 |
+
off_hz = off_zq * HQ + off_hq
|
| 231 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 232 |
+
if IS_DIVISIBLE:
|
| 233 |
+
tl.store(max_ptrs, m_i)
|
| 234 |
+
else:
|
| 235 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Utility triton funcs
|
| 239 |
+
@triton.jit
|
| 240 |
+
def get_offset_for_next_block(
|
| 241 |
+
loop_iter, col_indices, total_blocks,
|
| 242 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 243 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 244 |
+
):
|
| 245 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 246 |
+
return BLOCK
|
| 247 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 248 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 249 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 250 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 251 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 252 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 253 |
+
return offset
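A scalar sketch of the jump arithmetic above, under the constants used in this file (SPARSE_KV_BLOCK_SIZE = 128, BLOCK_N = 64, so SPARSE_KV_MULTIPLE = 2); col_indices here is a plain list standing in for one row of KV_IDX, and the function is illustrative rather than part of the generated kernel:

def next_block_offset(loop_iter, col_indices, SPARSE_BLOCK=128, MULTIPLE=2, BLOCK=64):
    cur = col_indices[loop_iter // MULTIPLE]
    nxt = col_indices[min(loop_iter // MULTIPLE + 1, len(col_indices) - 1)]
    if (loop_iter + 1) % MULTIPLE == 0:
        # end of the current sparse block: jump to the next selected one
        return (nxt - cur) * SPARSE_BLOCK - (MULTIPLE - 1) * BLOCK
    return BLOCK  # stay inside the current sparse block

# e.g. with col_indices = [0, 3]: iteration 0 advances by 64, iteration 1 jumps by
# 3*128 - 64 = 320, landing at kv position 384, the start of sparse block 3.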
|
| 254 |
+
|
| 255 |
+
@triton.jit
|
| 256 |
+
def get_bounded_indices(indices, max_len=None):
|
| 257 |
+
return indices % max_len if max_len is not None else indices
|
| 258 |
+
|
| 259 |
+
@triton.jit
|
| 260 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 261 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 262 |
+
return tl.load(block_ptr)
|
| 263 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 264 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 265 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 266 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 267 |
+
else:
|
| 268 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 269 |
+
|
| 270 |
+
@triton.jit
|
| 271 |
+
def load_checked_2d(
|
| 272 |
+
ptr,
|
| 273 |
+
offs_m,
|
| 274 |
+
offs_n,
|
| 275 |
+
stride_m,
|
| 276 |
+
stride_n,
|
| 277 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 278 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 279 |
+
M_LEN: tl.constexpr,
|
| 280 |
+
N_LEN: tl.constexpr,
|
| 281 |
+
):
|
| 282 |
+
# Calculate final pointer if strides are provided
|
| 283 |
+
if stride_m is not None and stride_n is not None:
|
| 284 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 285 |
+
|
| 286 |
+
# Handle all masking cases
|
| 287 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 288 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 289 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 290 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 291 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 292 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 293 |
+
else: # Both divisible
|
| 294 |
+
return tl.load(ptr)
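For intuition, a NumPy reference for the fully masked case of load_checked_2d, assuming offs_m and offs_n are integer index arrays into a dense 2-D array t; this is a sketch of the semantics, not code used by the kernel:

import numpy as np

def load_checked_2d_ref(t, offs_m, offs_n):
    # Positions whose offsets fall outside t read as 0.0, mirroring tl.load(..., other=0.0).
    out = np.zeros((len(offs_m), len(offs_n)), dtype=t.dtype)
    in_m = offs_m < t.shape[0]
    in_n = offs_n < t.shape[1]
    out[np.ix_(in_m, in_n)] = t[np.ix_(offs_m[in_m], offs_n[in_n])]
    return out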
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Common Imports
|
| 298 |
+
@triton.jit
|
| 299 |
+
def forward_block_mn(
|
| 300 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 301 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 302 |
+
# accumulated values
|
| 303 |
+
acc, l_i, m_i,
|
| 304 |
+
# Offsets
|
| 305 |
+
off_z, off_h, offs_m, offs_n,
|
| 306 |
+
# Offsets needed for TMA loads
|
| 307 |
+
kv_start,
|
| 308 |
+
kv_offset,
|
| 309 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 310 |
+
# Strides for K and V
|
| 311 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 312 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 313 |
+
|
| 314 |
+
):
|
| 315 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 316 |
+
PRESCALE_QK : tl.constexpr = False
|
| 317 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 318 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 319 |
+
WRITE_DQ : tl.constexpr = True
|
| 320 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 321 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 322 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 323 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 324 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 325 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 326 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 327 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 328 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 329 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 330 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 331 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 332 |
+
USE_TMA : tl.constexpr = False
|
| 333 |
+
BLOCK_M : tl.constexpr = 128
|
| 334 |
+
BLOCK_N : tl.constexpr = 64
|
| 335 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 336 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 337 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# -- load k --
|
| 341 |
+
# NB: reversed order since K is transposed
|
| 342 |
+
kv_base_offset = kv_start + kv_offset
|
| 343 |
+
|
| 344 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 345 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 346 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 347 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 348 |
+
|
| 349 |
+
k = tl.trans(k)
|
| 350 |
+
# -- compute qk ---
|
| 351 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 352 |
+
if not PRESCALE_QK:
|
| 353 |
+
qk *= SM_SCALE
|
| 354 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 355 |
+
# If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is more than the actual number of elements. To avoid out-of-bounds memory accesses,
# we need to mask out the elements that lie beyond Q_LEN & KV_LEN.
|
| 358 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 359 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 360 |
+
|
| 361 |
+
tmp0 = (qk)
|
| 362 |
+
post_mod_scores = tmp0
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 366 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 367 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 368 |
+
|
| 369 |
+
if not IS_FULL_BLOCKS:
|
| 370 |
+
tmp1 = (m)
|
| 371 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 372 |
+
tmp3 = tmp1 < tmp2
|
| 373 |
+
tmp4 = (n)
|
| 374 |
+
tmp5 = tmp4 <= tmp1
|
| 375 |
+
tmp6 = tmp3 & tmp5
|
| 376 |
+
tmp7 = tmp1 >= tmp2
|
| 377 |
+
tmp8 = tmp4 < tmp2
|
| 378 |
+
tmp9 = tmp7 & tmp8
|
| 379 |
+
tmp10 = tmp8 == 0
|
| 380 |
+
tmp11 = tmp7 & tmp10
|
| 381 |
+
tmp12 = tmp1 - tmp2
|
| 382 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 383 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 384 |
+
tmp15 = tmp4 - tmp2
|
| 385 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 386 |
+
tmp17 = tmp14 == tmp16
|
| 387 |
+
tmp18 = tmp11 & tmp17
|
| 388 |
+
tmp19 = tmp9 | tmp18
|
| 389 |
+
tmp20 = tmp6 | tmp19
|
| 390 |
+
mask_mod_output = tmp20
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 394 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 395 |
+
# apply mask for partially unmasked blocks
|
| 396 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 397 |
+
|
| 398 |
+
if not PRESCALE_QK:
|
| 399 |
+
post_mod_scores *= RCP_LN2
|
| 400 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 401 |
+
|
| 402 |
+
# -- compute scaling constant ---
|
| 403 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 404 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 405 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 406 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 407 |
+
else:
|
| 408 |
+
m_ij_masked = m_ij
|
| 409 |
+
|
| 410 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 411 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 412 |
+
|
| 413 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 414 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 415 |
+
# m_ij
|
| 416 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 417 |
+
# # -- scale and update acc --
|
| 418 |
+
acc = acc * alpha[:, None]
|
| 419 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 420 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 421 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 422 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 423 |
+
|
| 424 |
+
# -- update m_i
|
| 425 |
+
m_i = m_ij
|
| 426 |
+
|
| 427 |
+
return acc, l_i, m_i
|
| 428 |
+
|
| 429 |
+
@triton.jit
|
| 430 |
+
def forward_inner(
|
| 431 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 432 |
+
q, K, V,
|
| 433 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 434 |
+
# accumulated values
|
| 435 |
+
acc, l_i, m_i,
|
| 436 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 437 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 438 |
+
off_z, off_h, offs_m, offs_n,
|
| 439 |
+
# Offsets needed for TMA loads
|
| 440 |
+
kv_start,
|
| 441 |
+
# blocksparse data
|
| 442 |
+
kv_indices, kv_num_blocks,
|
| 443 |
+
# start kv and end kv block
|
| 444 |
+
block_n_start, block_n_end,
|
| 445 |
+
MATMUL_PRECISION,
|
| 446 |
+
# Strides for K and V
|
| 447 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 448 |
+
IS_FULL_BLOCKS,
|
| 449 |
+
):
|
| 450 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 451 |
+
PRESCALE_QK : tl.constexpr = False
|
| 452 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 453 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 454 |
+
WRITE_DQ : tl.constexpr = True
|
| 455 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 456 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 457 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 458 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 459 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 460 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 461 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 462 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 463 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 464 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 465 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 466 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 467 |
+
USE_TMA : tl.constexpr = False
|
| 468 |
+
BLOCK_M : tl.constexpr = 128
|
| 469 |
+
BLOCK_N : tl.constexpr = 64
|
| 470 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 471 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 472 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 476 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 477 |
+
|
| 478 |
+
if PRESCALE_QK:
|
| 479 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 480 |
+
|
| 481 |
+
kv_offset = 0
|
| 482 |
+
|
| 483 |
+
# loop over k, v and update accumulator until block_n_end
|
| 484 |
+
for start_n in range(block_n_start, block_n_end):
|
| 485 |
+
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 486 |
+
if IS_DIVISIBLE:
|
| 487 |
+
acc, l_i, m_i = forward_block_mn(
|
| 488 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 489 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 490 |
+
# accumulated values
|
| 491 |
+
acc, l_i, m_i,
|
| 492 |
+
# Offsets
|
| 493 |
+
off_z, off_h, offs_m, offs_n,
|
| 494 |
+
# Offsets needed for TMA loads
|
| 495 |
+
kv_start,
|
| 496 |
+
kv_offset,
|
| 497 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 498 |
+
# Strides for K and V
|
| 499 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 500 |
+
IS_FULL_BLOCKS,
|
| 501 |
+
)
|
| 502 |
+
else:
|
| 503 |
+
# Benchmarks show that applying mod & mask to every block for a non-divisible seqlen
# is on par with, or slightly faster than, applying them only to the last block in fwd.
# However, we choose a different strategy for bwd, where we only apply mod & mask
# to the last block because it is much faster there.
|
| 507 |
+
acc, l_i, m_i = forward_block_mn(
|
| 508 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 509 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 510 |
+
# accumulated values
|
| 511 |
+
acc, l_i, m_i,
|
| 512 |
+
# Offsets
|
| 513 |
+
off_z, off_h, offs_m, offs_n,
|
| 514 |
+
# Offsets needed for TMA loads
|
| 515 |
+
kv_start,
|
| 516 |
+
kv_offset,
|
| 517 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 518 |
+
# Strides for K and V
|
| 519 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 520 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
offset = get_offset_for_next_block(
|
| 526 |
+
start_n, kv_indices, kv_num_blocks,
|
| 527 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
offs_n = offs_n + offset
|
| 531 |
+
kv_offset += offset
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
return acc, l_i, m_i
|
progress/SpecForge/cache/compiled_kernels/5g/c5g7nnbi3zupsx7kdee2ed6g2fgrtd2jxyggsjpckfg5p7rps4qm.py
ADDED
@@ -0,0 +1,534 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
USE_TMA : tl.constexpr = False
|
| 36 |
+
BLOCK_M : tl.constexpr = 128
|
| 37 |
+
BLOCK_N : tl.constexpr = 64
|
| 38 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 39 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 41 |
+
Q = arg_Q
|
| 42 |
+
K = arg_K
|
| 43 |
+
V = arg_V
|
| 44 |
+
LSE = arg_LSE
|
| 45 |
+
MAX = arg_MAX
|
| 46 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 47 |
+
KV_IDX = arg_KV_IDX
|
| 48 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 49 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 50 |
+
|
| 51 |
+
# Sub notation for this kernel:
|
| 52 |
+
#
|
| 53 |
+
# Q: Query, K: Key, V: Value
|
| 54 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 55 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 56 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 57 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 58 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 59 |
+
#
|
| 60 |
+
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
| 61 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 62 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 63 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 64 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 65 |
+
#
|
| 66 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 67 |
+
#
|
| 68 |
+
# (Modifiable) Performance tuning options
|
| 69 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 70 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 71 |
+
|
| 72 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 73 |
+
# or involve a numerics vs. perf tradeoff
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 75 |
+
# about 20% more numerical error, but slightly faster.
|
| 76 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 77 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 78 |
+
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
| 79 |
+
# contiguous? If so, we don't need to do an indirect jump for every block
|
| 80 |
+
|
| 81 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 82 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 83 |
+
|
| 84 |
+
# Define strides of inputs
|
| 85 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 86 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 87 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 88 |
+
|
| 89 |
+
ZQ = 1
|
| 90 |
+
HQ = 32
|
| 91 |
+
Q_LEN = ks0
|
| 92 |
+
ZKV = 1
|
| 93 |
+
KV_LEN = ks1
|
| 94 |
+
|
| 95 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 96 |
+
|
| 97 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 98 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 99 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 100 |
+
|
| 101 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 102 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 103 |
+
off_zkv = off_zq % ZKV
|
| 104 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 105 |
+
off_g = off_hq % GQA_SHARED_HEADS
|
| 106 |
+
|
| 107 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 108 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 109 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 110 |
+
|
| 111 |
+
Q = Q + q_offset
|
| 112 |
+
K = K + k_offset
|
| 113 |
+
V = V + v_offset
|
| 114 |
+
|
| 115 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 116 |
+
desc_q = None
|
| 117 |
+
desc_k = None
|
| 118 |
+
desc_v = None
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 124 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 127 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 128 |
+
|
| 129 |
+
stride_kv_num_blks_h = 1
|
| 130 |
+
stride_kv_idx_h = 1
|
| 131 |
+
stride_kv_idx_m = 1
|
| 132 |
+
|
| 133 |
+
# initialize pointer to m and l
|
| 134 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 135 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 136 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 137 |
+
|
| 138 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 139 |
+
|
| 140 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 141 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 142 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 143 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 144 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 145 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 146 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 147 |
+
|
| 148 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 149 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 150 |
+
# both score_mod and mask_mod to it
|
| 151 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 152 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 153 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 154 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# K and V pointers will be passed directly to forward_inner
|
| 158 |
+
|
| 159 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
acc, l_i, m_i = forward_inner(
|
| 163 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 164 |
+
q, K, V,
|
| 165 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 166 |
+
acc, l_i, m_i,
|
| 167 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 168 |
+
kv_start,
|
| 169 |
+
kv_indices, kv_num_blocks,
|
| 170 |
+
0, block_n_end,
|
| 171 |
+
MATMUL_PRECISION,
|
| 172 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 173 |
+
IS_FULL_BLOCKS=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 177 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 178 |
+
# apply mask_mod to them - only score_mod
|
| 179 |
+
if HAS_FULL_BLOCKS:
|
| 180 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 181 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 182 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 183 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 184 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 185 |
+
# K and V pointers will be passed directly to forward_inner
|
| 186 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 190 |
+
q, K, V,
|
| 191 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 194 |
+
kv_start,
|
| 195 |
+
kv_indices, kv_num_blocks,
|
| 196 |
+
0, block_n_end,
|
| 197 |
+
MATMUL_PRECISION,
|
| 198 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 199 |
+
IS_FULL_BLOCKS=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# [Note] Handle fully masked out rows:
|
| 204 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 205 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 206 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 207 |
+
|
| 208 |
+
acc = acc / l_i[:, None]
|
| 209 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 210 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 211 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 212 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 213 |
+
|
| 214 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 215 |
+
|
| 216 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 217 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 218 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 219 |
+
|
| 220 |
+
if OUTPUT_LOGSUMEXP:
|
| 221 |
+
off_hz = off_zq * HQ + off_hq
|
| 222 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 223 |
+
lse = m_i + tl.math.log2(l_i)
|
| 224 |
+
if IS_DIVISIBLE:
|
| 225 |
+
tl.store(l_ptrs, lse)
|
| 226 |
+
else:
|
| 227 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
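Note that the LSE stored here is in the base-2 domain, since post_mod_scores were multiplied by RCP_LN2 before exp2; the natural-log quantity is recovered by a ln(2) multiply, which is presumably what triton_poi_fused_mul_1 earlier in this cache performs. A small sketch of that relationship, illustrative only:

import math

def lse_natural(m_i_row, l_i_row):
    # Convert the base-2 running max and denominator into a natural-log logsumexp.
    return math.log(2) * (m_i_row + math.log2(l_i_row))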
|
| 228 |
+
|
| 229 |
+
if OUTPUT_MAX:
|
| 230 |
+
off_hz = off_zq * HQ + off_hq
|
| 231 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 232 |
+
if IS_DIVISIBLE:
|
| 233 |
+
tl.store(max_ptrs, m_i)
|
| 234 |
+
else:
|
| 235 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Utility triton funcs
|
| 239 |
+
@triton.jit
|
| 240 |
+
def get_offset_for_next_block(
|
| 241 |
+
loop_iter, col_indices, total_blocks,
|
| 242 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 243 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 244 |
+
):
|
| 245 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 246 |
+
return BLOCK
|
| 247 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 248 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 249 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 250 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 251 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 252 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 253 |
+
return offset
|
| 254 |
+
|
| 255 |
+
@triton.jit
|
| 256 |
+
def get_bounded_indices(indices, max_len=None):
|
| 257 |
+
return indices % max_len if max_len is not None else indices
|
| 258 |
+
|
| 259 |
+
@triton.jit
|
| 260 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 261 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 262 |
+
return tl.load(block_ptr)
|
| 263 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 264 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 265 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 266 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 267 |
+
else:
|
| 268 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 269 |
+
|
| 270 |
+
@triton.jit
|
| 271 |
+
def load_checked_2d(
|
| 272 |
+
ptr,
|
| 273 |
+
offs_m,
|
| 274 |
+
offs_n,
|
| 275 |
+
stride_m,
|
| 276 |
+
stride_n,
|
| 277 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 278 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 279 |
+
M_LEN: tl.constexpr,
|
| 280 |
+
N_LEN: tl.constexpr,
|
| 281 |
+
):
|
| 282 |
+
# Calculate final pointer if strides are provided
|
| 283 |
+
if stride_m is not None and stride_n is not None:
|
| 284 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 285 |
+
|
| 286 |
+
# Handle all masking cases
|
| 287 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 288 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 289 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 290 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 291 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 292 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 293 |
+
else: # Both divisible
|
| 294 |
+
return tl.load(ptr)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # -- load k --
    # NB reversed order to since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid access memory out of bound,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20


        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

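# Illustrative sketch (not part of the generated cache file): forward_block_mn
# is one step of the flash-attention style "online softmax". The same running
# max / alpha / l_i bookkeeping, written in base-e NumPy instead of the base-2
# exp2 form used above, and checked against a full softmax over two blocks:
import numpy as np

def online_softmax_step(acc, l_i, m_i, scores, v):
    m_ij = np.maximum(m_i, scores.max(axis=1))    # new running row max
    alpha = np.exp(m_i - m_ij)                    # rescale factor for the old state
    p = np.exp(scores - m_ij[:, None])            # unnormalised probabilities for this block
    l_i = l_i * alpha + p.sum(axis=1)             # running softmax denominator
    acc = acc * alpha[:, None] + p @ v            # running weighted sum of V
    return acc, l_i, m_ij

_rng = np.random.default_rng(0)
_s1, _s2 = _rng.normal(size=(4, 16)), _rng.normal(size=(4, 16))
_v1, _v2 = _rng.normal(size=(16, 8)), _rng.normal(size=(16, 8))
_acc, _l, _m = np.zeros((4, 8)), np.zeros(4), np.full(4, -np.inf)
for _s, _v in ((_s1, _v1), (_s2, _v2)):
    _acc, _l, _m = online_softmax_step(_acc, _l, _m, _s, _v)
_full = np.exp(np.hstack([_s1, _s2]) - np.hstack([_s1, _s2]).max(axis=1, keepdims=True))
_expected = (_full / _full.sum(axis=1, keepdims=True)) @ np.vstack([_v1, _v2])
assert np.allclose(_acc / _l[:, None], _expected)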
@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
            # it's on par or slightly faster than only applying to the last block in fwd.
            # However, we choose different strategy for bwd, where we only apply mod & mask
            # to the last block because it's faster a lot.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )



        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset


    return acc, l_i, m_i
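# Illustrative sketch (not part of the generated cache file): the offset
# returned by get_offset_for_next_block advances BLOCK_N elements inside the
# current sparse KV block and jumps to the next block listed in kv_indices once
# the current one is consumed. A pure-Python rendering (clamping the index
# where the kernel relies on a masked tl.load):
def next_offset(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = col_indices[cur_block_idx]
    next_block = col_indices[min(cur_block_idx + 1, total_blocks - 1)]
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump_to_block if needs_jump else BLOCK

# Usage: BLOCK_N=64 tiles walking sparse 128-wide KV blocks 0, 3 and 7.
_cols, _pos, _visited = [0, 3, 7], 0, []
for _it in range(6):                      # 3 sparse blocks x SPARSE_KV_MULTIPLE(=2) tiles each
    _visited.append(_pos)
    _pos += next_offset(_it, _cols, len(_cols), 128, 2, 64)
assert _visited == [0, 64, 384, 448, 896, 960]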
progress/SpecForge/cache/compiled_kernels/5j/3cc65a0fdb544c73efb7240355b77da3f1ab394b46f272fa923c368e6cc63c34.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 96, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"}
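# Illustrative sketch (not part of the generated cache): each *.best_config
# file is a single JSON object recording the autotuned launch configuration for
# the matching kernel. Inspecting one from a local checkout (the path below is
# just the file added in this diff; adjust to wherever the cache lives):
import json
from pathlib import Path

cfg_path = Path("progress/SpecForge/cache/compiled_kernels/5j/"
                "3cc65a0fdb544c73efb7240355b77da3f1ab394b46f272fa923c368e6cc63c34.best_config")
cfg = json.loads(cfg_path.read_text())
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])   # 128 8 1 for this entry
print("autotuning took (ms):", cfg["time_taken_ms"])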
progress/SpecForge/cache/compiled_kernels/5j/c5j7yk5hlaaxs42qwjlmoczwtoukaw2dio2o6p7qfekdy5upikyv.py
ADDED
@@ -0,0 +1,54 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 32},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 32
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    x5 = xindex
    x1 = xindex // 128
    x0 = (xindex % 128)
    x3 = ((xindex // 128) % ks0)
    x4 = xindex // ks1
    tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
    tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
    tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
    tmp2 = float("-inf")
    tmp3 = tmp1 == tmp2
    tmp5 = tmp4 - tmp1
    tmp6 = 0.0
    tmp7 = tl.where(tmp3, tmp6, tmp5)
    tmp8 = libdevice.exp2(tmp7)
    tmp9 = tmp0 * tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
    tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
    tmp14 = 1.0
    tmp15 = tl.where(tmp3, tmp14, tmp13)
    tmp16 = (tmp12 / tmp15)
    tmp17 = tmp16.to(tl.float32)
    tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
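# Illustrative sketch (not part of the generated cache file): arithmetically,
# the reduction above has the shape of a split-KV combine — rescale each
# split's partial accumulator from its local max to the global max (base 2),
# sum over the 32 splits, then divide by the global normaliser, treating fully
# masked rows (max == -inf) as scale 1 / denominator 1. Names and shapes below
# are illustrative, not the kernel's actual indexing:
import numpy as np

def combine_splits(partial_acc, split_max, global_max, global_sum):
    # partial_acc: (splits, rows, head_dim); split_max: (splits, rows);
    # global_max, global_sum: (rows,)
    safe = global_max != -np.inf
    scale = np.exp2(np.where(safe, split_max - global_max[None, :], 0.0))
    summed = (partial_acc * scale[..., None]).sum(axis=0)
    denom = np.where(safe, global_sum, 1.0)
    return summed / denom[:, None]

_acc = np.zeros((4, 2, 8)); _acc[0] = 1.0
_out = combine_splits(_acc, np.zeros((4, 2)), np.zeros(2), np.full(2, 4.0))
assert _out.shape == (2, 8) and np.allclose(_out, 0.25)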
progress/SpecForge/cache/compiled_kernels/5w/c5wutjfcact264ykgcamj2asvz4eqe3ygz47upjgib2qw5rnnihu.py
ADDED
@@ -0,0 +1,1019 @@
# AOT ID: ['4_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p

# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ns/cnsdsjpw2wa2anwrjd5vp4pkfq4mtbebhfmyzeol2hd4mzwu6qnk.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# Graph fragment:
|
| 41 |
+
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem]
|
| 42 |
+
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
|
| 43 |
+
# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
|
| 44 |
+
# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=tangents_2]
|
| 45 |
+
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 46 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 47 |
+
# return %buf0,%buf1
|
| 48 |
+
triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', '''
|
| 49 |
+
import triton
|
| 50 |
+
import triton.language as tl
|
| 51 |
+
|
| 52 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 53 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 54 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 55 |
+
triton_helpers.set_driver_to_gpu()
|
| 56 |
+
|
| 57 |
+
@triton_heuristics.persistent_reduction(
|
| 58 |
+
size_hints={'x': 4096, 'r0_': 128},
|
| 59 |
+
reduction_hint=ReductionHint.INNER,
|
| 60 |
+
filename=__file__,
|
| 61 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
|
| 62 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 63 |
+
)
|
| 64 |
+
@triton.jit
|
| 65 |
+
def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 66 |
+
r0_numel = 128
|
| 67 |
+
R0_BLOCK: tl.constexpr = 128
|
| 68 |
+
rnumel = r0_numel
|
| 69 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 70 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 71 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 72 |
+
xmask = xindex < xnumel
|
| 73 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 74 |
+
r0_offset = 0
|
| 75 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 76 |
+
roffset = r0_offset
|
| 77 |
+
rindex = r0_index
|
| 78 |
+
r0_2 = r0_index
|
| 79 |
+
x0 = (xindex % ks0)
|
| 80 |
+
x1 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 81 |
+
x3 = xindex
|
| 82 |
+
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
|
| 83 |
+
tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
|
| 84 |
+
tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
|
| 85 |
+
tmp2 = tmp0 * tmp1
|
| 86 |
+
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
|
| 87 |
+
tmp5 = tl.where(xmask, tmp3, 0)
|
| 88 |
+
tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
|
| 89 |
+
tmp7 = tmp6.to(tl.float32)
|
| 90 |
+
tmp9 = 0.6931471805599453
|
| 91 |
+
tmp10 = tmp8 * tmp9
|
| 92 |
+
tmp11 = 1.4426950408889634
|
| 93 |
+
tmp12 = tmp10 * tmp11
|
| 94 |
+
tmp13 = tmp7 - tmp12
|
| 95 |
+
tl.store(out_ptr1 + (x3), tmp13, xmask)
|
| 96 |
+
''', device_str='cuda')
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py
|
| 100 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 101 |
+
# Source node to ATen node mapping:
|
| 102 |
+
# Graph fragment:
|
| 103 |
+
# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2]
|
| 104 |
+
# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4]
|
| 105 |
+
# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6]
|
| 106 |
+
# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1]
|
| 107 |
+
# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
|
| 108 |
+
# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1]
|
| 109 |
+
# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3]
|
| 110 |
+
# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5]
|
| 111 |
+
# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_10]
|
| 112 |
+
# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_7]
|
| 113 |
+
# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_13]
|
| 114 |
+
# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_14]
|
| 115 |
+
# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_11]
|
| 116 |
+
# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_12]
|
| 117 |
+
# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_15]
|
| 118 |
+
# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_16]
|
| 119 |
+
# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {})
|
| 120 |
+
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 121 |
+
# return %getitem_4
|
| 122 |
+
triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', '''
|
| 123 |
+
import triton
|
| 124 |
+
import triton.language as tl
|
| 125 |
+
|
| 126 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 127 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 128 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 129 |
+
|
| 130 |
+
@triton_heuristics.template(
|
| 131 |
+
|
| 132 |
+
num_stages=3,
|
| 133 |
+
num_warps=8,
|
| 134 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 135 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 136 |
+
|
| 137 |
+
)
|
| 138 |
+
@triton.jit
|
| 139 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1):
|
| 140 |
+
PRESCALE_QK : tl.constexpr = False
|
| 141 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 142 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 143 |
+
WRITE_DQ : tl.constexpr = True
|
| 144 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 145 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 146 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 147 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 148 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 149 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 150 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 151 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 152 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 153 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 154 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 155 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 156 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 157 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 158 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 159 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 160 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 161 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 162 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 163 |
+
Q = arg_Q
|
| 164 |
+
K = arg_K
|
| 165 |
+
V = arg_V
|
| 166 |
+
LSE = arg_LSE
|
| 167 |
+
DELTA = arg_DELTA
|
| 168 |
+
DO = arg_DO
|
| 169 |
+
DQ = arg_DQ
|
| 170 |
+
DV = arg_DV
|
| 171 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 172 |
+
KV_IDX = arg_KV_IDX
|
| 173 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 174 |
+
Q_IDX = arg_Q_IDX
|
| 175 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 176 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 177 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 178 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 179 |
+
|
| 180 |
+
# Sub notation for this kernel:
|
| 181 |
+
#
|
| 182 |
+
# Q: Query, K: Key, V: Value
|
| 183 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 184 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 185 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 186 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 187 |
+
# inductor codegen
|
| 188 |
+
# M: Number of queries, N: Number of keys/values
|
| 189 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 190 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 191 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 192 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 193 |
+
# (Modifiable) Performance tuning options
|
| 194 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 195 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 196 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 197 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 198 |
+
#
|
| 199 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 200 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 201 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 202 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 203 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 204 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 205 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 206 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 207 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 208 |
+
|
| 209 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 210 |
+
# or involve a numerics vs. perf tradeoff
|
| 211 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 212 |
+
# about 20% more numerical error, but slightly faster.
|
| 213 |
+
|
| 214 |
+
# Define strides of inputs
|
| 215 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
|
| 216 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1
|
| 217 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1
|
| 218 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
|
| 219 |
+
|
| 220 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
|
| 221 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1
|
| 222 |
+
|
| 223 |
+
ZQ = 1
|
| 224 |
+
HQ = 32
|
| 225 |
+
HKV = 8
|
| 226 |
+
Q_LEN = ks0
|
| 227 |
+
ZKV = 1
|
| 228 |
+
KV_LEN = ks1
|
| 229 |
+
|
| 230 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 231 |
+
|
| 232 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 233 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 234 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 235 |
+
|
| 236 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 237 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 238 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 239 |
+
|
| 240 |
+
SPARSE_Z = 1
|
| 241 |
+
SPARSE_HQ = 1
|
| 242 |
+
|
| 243 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 244 |
+
|
| 245 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 246 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 247 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 248 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 249 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 250 |
+
|
| 251 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 252 |
+
K += k_adj
|
| 253 |
+
V += v_adj
|
| 254 |
+
DV += dv_adj
|
| 255 |
+
|
| 256 |
+
RCP_LN2 = 1.44269504
|
| 257 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 258 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 259 |
+
|
| 260 |
+
if pid >= NUM_KV_BLOCKS:
|
| 261 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 262 |
+
# THIS BLOCK DOES DQ
|
| 263 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 264 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 265 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 266 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 267 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 268 |
+
stride_kv_num_blks_h = 1
|
| 269 |
+
stride_kv_idx_h = 1
|
| 270 |
+
stride_kv_idx_m = 1
|
| 271 |
+
|
| 272 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 273 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 274 |
+
|
| 275 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 276 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 277 |
+
|
| 278 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 279 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 280 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 281 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 282 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 283 |
+
|
| 284 |
+
Q2 = Q + q_adj2
|
| 285 |
+
DO2 = DO + do_adj2
|
| 286 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 287 |
+
# if Q is broadcasted)
|
| 288 |
+
DQ2 = DQ + dq_adj2
|
| 289 |
+
LSE2 = LSE + off_chz2
|
| 290 |
+
DELTA2 = DELTA + off_chz2
|
| 291 |
+
|
| 292 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 293 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 294 |
+
|
| 295 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 296 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 297 |
+
|
| 298 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 299 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 300 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 301 |
+
|
| 302 |
+
if PRESCALE_QK:
|
| 303 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 304 |
+
|
| 305 |
+
if IS_DIVISIBLE:
|
| 306 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 307 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 308 |
+
else:
|
| 309 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 310 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 311 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 312 |
+
lse = lse[:, None]
|
| 313 |
+
|
| 314 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 315 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 316 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 317 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 318 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 319 |
+
|
| 320 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 321 |
+
dq = bwd_dq_inner(
|
| 322 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 323 |
+
K, V,
|
| 324 |
+
dq, q, do, Di, lse,
|
| 325 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 326 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 327 |
+
kv_indices, sparse_kv_num_blocks,
|
| 328 |
+
MATMUL_PRECISION,
|
| 329 |
+
IS_FULL_BLOCKS=False,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if HAS_FULL_BLOCKS:
|
| 333 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 334 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 335 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 336 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 337 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 338 |
+
|
| 339 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 340 |
+
dq = bwd_dq_inner(
|
| 341 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 342 |
+
K, V,
|
| 343 |
+
dq, q, do, Di, lse,
|
| 344 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 345 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 346 |
+
kv_indices, sparse_kv_num_blocks,
|
| 347 |
+
MATMUL_PRECISION,
|
| 348 |
+
IS_FULL_BLOCKS=True,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# Write back dQ.
|
| 352 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 353 |
+
dq *= SM_SCALE
|
| 354 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 355 |
+
tl.store(dq_ptrs, dq)
|
| 356 |
+
else:
|
| 357 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 358 |
+
else:
|
| 359 |
+
# THIS BLOCK DOES DK & DV
|
| 360 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 361 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 362 |
+
|
| 363 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 364 |
+
|
| 365 |
+
stride_q_num_blks_h = 1
|
| 366 |
+
stride_q_idx_h = 1
|
| 367 |
+
stride_q_idx_n = 1
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 371 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 372 |
+
|
| 373 |
+
start_n1 = pid * BLOCK_N1
|
| 374 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 375 |
+
|
| 376 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 377 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 378 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 379 |
+
|
| 380 |
+
if PRESCALE_QK:
|
| 381 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 382 |
+
|
| 383 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 384 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 385 |
+
|
| 386 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 387 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 388 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 389 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 390 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 391 |
+
|
| 392 |
+
Q1 = Q + q_adj1
|
| 393 |
+
DO1 = DO + do_adj1
|
| 394 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 395 |
+
# if Q is broadcasted)
|
| 396 |
+
LSE1 = LSE + off_chz1
|
| 397 |
+
DELTA1 = DELTA + off_chz1
|
| 398 |
+
|
| 399 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 400 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 401 |
+
|
| 402 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 403 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 404 |
+
|
| 405 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 406 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 407 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 408 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 409 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 410 |
+
|
| 411 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 412 |
+
dk, dv = bwd_dkdv_inner(
|
| 413 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 414 |
+
Q1, DO1, DELTA1, LSE1,
|
| 415 |
+
dk, dv, k, v,
|
| 416 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 417 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 418 |
+
q_indices, sparse_q_num_blocks,
|
| 419 |
+
MATMUL_PRECISION,
|
| 420 |
+
IS_FULL_BLOCKS=False,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if HAS_FULL_BLOCKS:
|
| 425 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 426 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 427 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 428 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 429 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 430 |
+
|
| 431 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 432 |
+
dk, dv = bwd_dkdv_inner(
|
| 433 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 434 |
+
Q1, DO1, DELTA1, LSE1,
|
| 435 |
+
dk, dv, k, v,
|
| 436 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 437 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 438 |
+
q_indices, sparse_q_num_blocks,
|
| 439 |
+
MATMUL_PRECISION,
|
| 440 |
+
IS_FULL_BLOCKS=True,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Write back dV and dK.
|
| 444 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 445 |
+
|
| 446 |
+
index_n = offs_n1[:, None]
|
| 447 |
+
index_k = offs_k[None, :]
|
| 448 |
+
index_v = offs_v[None, :]
|
| 449 |
+
|
| 450 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 451 |
+
tl.store(dv_ptrs, dv)
|
| 452 |
+
else:
|
| 453 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 454 |
+
|
| 455 |
+
dk *= SM_SCALE
|
| 456 |
+
|
| 457 |
+
if SAFE_HEAD_DIM:
|
| 458 |
+
mask = index_n < KV_LEN
|
| 459 |
+
else:
|
| 460 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 461 |
+
|
| 462 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 463 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 464 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 465 |
+
xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
|
| 466 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
|
| 467 |
+
|
| 468 |
+
@triton.jit
|
| 469 |
+
def bwd_dq_inner(
|
| 470 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 471 |
+
K, V, # pointers
|
| 472 |
+
dq, q, do, Di, lse,
|
| 473 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 474 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 475 |
+
kv_indices, sparse_kv_num_blocks,
|
| 476 |
+
MATMUL_PRECISION,
|
| 477 |
+
IS_FULL_BLOCKS,
|
| 478 |
+
):
|
| 479 |
+
PRESCALE_QK : tl.constexpr = False
|
| 480 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 481 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 482 |
+
WRITE_DQ : tl.constexpr = True
|
| 483 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 484 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 485 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 486 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 487 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 488 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 489 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 490 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 491 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 492 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 493 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 494 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 495 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 496 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 497 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 498 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 499 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 500 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 501 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 502 |
+
|
| 503 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 504 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 505 |
+
Q_LEN = ks0
|
| 506 |
+
KV_LEN = ks1
|
| 507 |
+
|
| 508 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 509 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 510 |
+
|
| 511 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 512 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 513 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 514 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 515 |
+
|
| 516 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 517 |
+
|
| 518 |
+
for start_n in range(0, hi):
|
| 519 |
+
dq = bwd_dq_block_mn(
|
| 520 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 521 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 522 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 523 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 524 |
+
kv_indices, sparse_kv_num_blocks,
|
| 525 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 526 |
+
IS_FULL_BLOCKS,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
# Increment pointers.
|
| 530 |
+
offset = get_offset_for_next_block(
|
| 531 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 532 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
kT_ptrs += offset * stride_kn
|
| 536 |
+
vT_ptrs += offset * stride_vn
|
| 537 |
+
|
| 538 |
+
offs_n2 += offset
|
| 539 |
+
|
| 540 |
+
return dq
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
@triton.jit
|
| 544 |
+
def bwd_dq_block_mn(
|
| 545 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 546 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 547 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 548 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 549 |
+
kv_indices, sparse_kv_num_blocks,
|
| 550 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 551 |
+
IS_FULL_BLOCKS,
|
| 552 |
+
):
|
| 553 |
+
PRESCALE_QK : tl.constexpr = False
|
| 554 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 555 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 556 |
+
WRITE_DQ : tl.constexpr = True
|
| 557 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 558 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 559 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 560 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 561 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 562 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 563 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 564 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 567 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 568 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 569 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 570 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 571 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 572 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 573 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 574 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 575 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# NB reversed order to since K is transposed
|
| 579 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 580 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 581 |
+
if not PRESCALE_QK:
|
| 582 |
+
qk *= SM_SCALE
|
| 583 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 584 |
+
pre_mod_scores = qk
|
| 585 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 586 |
+
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
|
| 587 |
+
# that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
|
| 588 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 589 |
+
|
| 590 |
+
tmp0 = (qk)
|
| 591 |
+
post_mod_scores = tmp0
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
if not IS_DIVISIBLE:
|
| 597 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 598 |
+
|
| 599 |
+
if not IS_FULL_BLOCKS:
|
| 600 |
+
tmp1 = (m)
|
| 601 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 602 |
+
tmp3 = tmp1 < tmp2
|
| 603 |
+
tmp4 = (n)
|
| 604 |
+
tmp5 = tmp4 <= tmp1
|
| 605 |
+
tmp6 = tmp3 & tmp5
|
| 606 |
+
tmp7 = tmp1 >= tmp2
|
| 607 |
+
tmp8 = tmp4 < tmp2
|
| 608 |
+
tmp9 = tmp7 & tmp8
|
| 609 |
+
tmp10 = tmp8 == 0
|
| 610 |
+
tmp11 = tmp7 & tmp10
|
| 611 |
+
tmp12 = tmp1 - tmp2
|
| 612 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 613 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 614 |
+
tmp15 = tmp4 - tmp2
|
| 615 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 616 |
+
tmp17 = tmp14 == tmp16
|
| 617 |
+
tmp18 = tmp11 & tmp17
|
| 618 |
+
tmp19 = tmp9 | tmp18
|
| 619 |
+
tmp20 = tmp6 | tmp19
|
| 620 |
+
mask_mod_output = tmp20
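# NOTE (editorial interpretation, not generated code): with m and n already bounded to
# non-negative in-range indices, the predicate above reduces to its tmp11 & tmp17 branch,
# i.e. (m // 16) == (n // 16) -- a block-diagonal mask with block size 16.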
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
# apply mask for partial masked block
|
| 624 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 625 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 626 |
+
if not PRESCALE_QK:
|
| 627 |
+
post_mod_scores *= RCP_LN2
|
| 628 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 629 |
+
# Compute dP and dS.
|
| 630 |
+
# NB reversed order since V is transposed
|
| 631 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 632 |
+
|
| 633 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 634 |
+
ds = p * (dp - Di[:, None])
|
| 635 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 636 |
+
tmp21 = (ds)
|
| 637 |
+
grad_scores = tmp21
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
if not IS_DIVISIBLE:
|
| 641 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 642 |
+
|
| 643 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 644 |
+
if WRITE_DQ:
|
| 645 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 646 |
+
|
| 647 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 648 |
+
ds = grad_scores
|
| 649 |
+
|
| 650 |
+
if not IS_FULL_BLOCKS:
|
| 651 |
+
# (grads) apply mask for partially unmasked block
|
| 652 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 653 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 654 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 655 |
+
# Compute dQ.
|
| 656 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 657 |
+
|
| 658 |
+
return dq
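# --- Editorial reference sketch (assumptions: plain PyTorch, a single [M, N] block,
# no masking/GQA bookkeeping, SM_SCALE handling left to the surrounding code; the
# helper name is hypothetical, not part of the generated kernel) ---
import torch

def dq_block_reference(q, k, v, do, lse, delta, sm_scale):
    """Eager-mode view of the dQ update performed by bwd_dq_block_mn above."""
    p = torch.exp(q @ k.T * sm_scale - lse[:, None])   # softmax probs recovered from the logsumexp
    dp = do @ v.T                                       # dP = dO @ V^T
    ds = p * (dp - delta[:, None])                      # dS = P * (dP - Delta)
    return ds @ k                                       # this block's contribution to dQ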
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
@triton.jit
|
| 662 |
+
def bwd_dkdv_inner(
|
| 663 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 664 |
+
Q, DO, DELTA, LSE, # pointers
|
| 665 |
+
dk, dv, k, v,
|
| 666 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 667 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 668 |
+
q_indices, sparse_q_num_blocks,
|
| 669 |
+
MATMUL_PRECISION,
|
| 670 |
+
IS_FULL_BLOCKS,
|
| 671 |
+
):
|
| 672 |
+
PRESCALE_QK : tl.constexpr = False
|
| 673 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 674 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 675 |
+
WRITE_DQ : tl.constexpr = True
|
| 676 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 677 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 678 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 679 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 680 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 681 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 682 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 683 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 684 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 685 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 686 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 687 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 688 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 689 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 690 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 691 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 692 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 693 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 694 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 695 |
+
|
| 696 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 697 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 698 |
+
Q_LEN = ks0
|
| 699 |
+
KV_LEN = ks1
|
| 700 |
+
|
| 701 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 702 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 703 |
+
|
| 704 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 705 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 706 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 707 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 708 |
+
|
| 709 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 710 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 711 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 712 |
+
|
| 713 |
+
for start_m in range(0, hi):
|
| 714 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 715 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 716 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 717 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 718 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 719 |
+
q_indices, sparse_q_num_blocks,
|
| 720 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 721 |
+
IS_FULL_BLOCKS,
|
| 722 |
+
)
|
| 723 |
+
# Increment pointers.
|
| 724 |
+
offset = get_offset_for_next_block(
|
| 725 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 726 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
qT_ptrs += offset * stride_qm
|
| 730 |
+
do_ptrs += offset * stride_dom
|
| 731 |
+
offs_m1 += offset
|
| 732 |
+
|
| 733 |
+
return dk, dv
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
@triton.jit
|
| 737 |
+
def bwd_dkdv_block_mn(
|
| 738 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1,
|
| 739 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 740 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 741 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 742 |
+
q_indices, sparse_q_num_blocks,
|
| 743 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 744 |
+
IS_FULL_BLOCKS,
|
| 745 |
+
):
|
| 746 |
+
PRESCALE_QK : tl.constexpr = False
|
| 747 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 748 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 749 |
+
WRITE_DQ : tl.constexpr = True
|
| 750 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 751 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 752 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 753 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 754 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 755 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 756 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 757 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 758 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 759 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 760 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 761 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 762 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 763 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 764 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 765 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 766 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 767 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 768 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# NB reversed order since Q is transposed
|
| 772 |
+
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
| 773 |
+
# Load LSE before computing qk to reduce pipeline stall.
|
| 774 |
+
if IS_DIVISIBLE:
|
| 775 |
+
lse = tl.load(LSE + offs_m1)
|
| 776 |
+
else:
|
| 777 |
+
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
| 778 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 779 |
+
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
| 780 |
+
if not PRESCALE_QK:
|
| 781 |
+
qkT *= SM_SCALE
|
| 782 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 783 |
+
m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
|
| 784 |
+
# The boundary check is done for the outer loop, but here, since we're iterating across the M dim, it's possible
|
| 785 |
+
# that the n reads go out of bounds for the PIDs spanning the KV_LEN boundary
|
| 786 |
+
n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
|
| 787 |
+
|
| 788 |
+
pre_mod_scores = qkT
|
| 789 |
+
tmp22 = (qkT)
|
| 790 |
+
post_mod_scores = tmp22
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
if not IS_DIVISIBLE:
|
| 795 |
+
post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
|
| 796 |
+
|
| 797 |
+
if not IS_FULL_BLOCKS:
|
| 798 |
+
tmp23 = (m)
|
| 799 |
+
tmp24 = tl.full([1], 0, tl.int32)
|
| 800 |
+
tmp25 = tmp23 < tmp24
|
| 801 |
+
tmp26 = (n)
|
| 802 |
+
tmp27 = tmp26 <= tmp23
|
| 803 |
+
tmp28 = tmp25 & tmp27
|
| 804 |
+
tmp29 = tmp23 >= tmp24
|
| 805 |
+
tmp30 = tmp26 < tmp24
|
| 806 |
+
tmp31 = tmp29 & tmp30
|
| 807 |
+
tmp32 = tmp30 == 0
|
| 808 |
+
tmp33 = tmp29 & tmp32
|
| 809 |
+
tmp34 = tmp23 - tmp24
|
| 810 |
+
tmp35 = tl.full([1], 16, tl.int32)
|
| 811 |
+
tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
|
| 812 |
+
tmp37 = tmp26 - tmp24
|
| 813 |
+
tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
|
| 814 |
+
tmp39 = tmp36 == tmp38
|
| 815 |
+
tmp40 = tmp33 & tmp39
|
| 816 |
+
tmp41 = tmp31 | tmp40
|
| 817 |
+
tmp42 = tmp28 | tmp41
|
| 818 |
+
mask_mod_output = tmp42
|
| 819 |
+
|
| 820 |
+
# (grads) apply mask for fully masked block
|
| 821 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 822 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 823 |
+
if not PRESCALE_QK:
|
| 824 |
+
post_mod_scores *= RCP_LN2
|
| 825 |
+
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
| 826 |
+
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 827 |
+
# Compute dV.
|
| 828 |
+
ppT = pT
|
| 829 |
+
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
| 830 |
+
if IS_DIVISIBLE:
|
| 831 |
+
Di = tl.load(DELTA + offs_m1)
|
| 832 |
+
else:
|
| 833 |
+
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
| 834 |
+
# Compute dP and dS.
|
| 835 |
+
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
| 836 |
+
dsT = pT * (dpT - Di[None, :])
|
| 837 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 838 |
+
tmp43 = (dsT)
|
| 839 |
+
grad_scores = tmp43
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
if not IS_DIVISIBLE:
|
| 844 |
+
grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
|
| 845 |
+
|
| 846 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 847 |
+
if not WRITE_DQ:
|
| 848 |
+
idx_b = off_z
|
| 849 |
+
idx_h = off_hq
|
| 850 |
+
idx_m = m
|
| 851 |
+
idx_n = n
|
| 852 |
+
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
| 853 |
+
|
| 854 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 855 |
+
dsT = grad_scores
|
| 856 |
+
if not IS_FULL_BLOCKS:
|
| 857 |
+
# (grads) apply mask for partially unmasked block
|
| 858 |
+
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
| 859 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 860 |
+
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
| 861 |
+
|
| 862 |
+
return dk, dv
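# --- Editorial reference sketch (assumptions: plain PyTorch, a single block, no
# masking/GQA bookkeeping, SM_SCALE handling left to the surrounding code; the
# helper name is hypothetical, not part of the generated kernel) ---
import torch

def dkdv_block_reference(q, k, v, do, lse, delta, sm_scale):
    """Eager-mode view of the dK/dV update performed by bwd_dkdv_block_mn above."""
    pT = torch.exp(k @ q.T * sm_scale - lse[None, :])   # P^T recovered from the logsumexp
    dv = pT @ do                                        # dV = P^T @ dO
    dpT = v @ do.T                                      # dP^T = V @ dO^T
    dsT = pT * (dpT - delta[None, :])                   # dS^T = P^T * (dP^T - Delta)
    dk = dsT @ q                                        # this block's contribution to dK
    return dk, dv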
|
| 863 |
+
|
| 864 |
+
# Utility triton funcs
|
| 865 |
+
@triton.jit
|
| 866 |
+
def get_offset_for_next_block(
|
| 867 |
+
loop_iter, col_indices, total_blocks,
|
| 868 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 869 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 870 |
+
):
|
| 871 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 872 |
+
return BLOCK
|
| 873 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 874 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 875 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 876 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 877 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 878 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 879 |
+
return offset
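# Editorial worked example (not generated code): with SPARSE_BLOCK=128 and BLOCK=64
# (so SPARSE_BLOCK_MULTIPLE=2) and col_indices=[0, 3, ...]:
#   loop_iter=0 -> needs_jump is False, offset = BLOCK = 64
#                  (advance into the second half of sparse block 0)
#   loop_iter=1 -> needs_jump is True, jump_to_block = (3 - 0)*128 - 1*64 = 320
#                  (from element 64 the pointer lands at 64 + 320 = 384, the start of sparse block 3)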
|
| 880 |
+
|
| 881 |
+
@triton.jit
|
| 882 |
+
def get_bounded_indices(indices, max_len=None):
|
| 883 |
+
return indices % max_len if max_len is not None else indices
|
| 884 |
+
|
| 885 |
+
@triton.jit
|
| 886 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 887 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 888 |
+
return tl.load(block_ptr)
|
| 889 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 890 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 891 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 892 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 893 |
+
else:
|
| 894 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 895 |
+
|
| 896 |
+
@triton.jit
|
| 897 |
+
def load_checked_2d(
|
| 898 |
+
ptr,
|
| 899 |
+
offs_m,
|
| 900 |
+
offs_n,
|
| 901 |
+
stride_m,
|
| 902 |
+
stride_n,
|
| 903 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 904 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 905 |
+
M_LEN: tl.constexpr,
|
| 906 |
+
N_LEN: tl.constexpr,
|
| 907 |
+
):
|
| 908 |
+
# Calculate final pointer if strides are provided
|
| 909 |
+
if stride_m is not None and stride_n is not None:
|
| 910 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 911 |
+
|
| 912 |
+
# Handle all masking cases
|
| 913 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 914 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 915 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 916 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 917 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 918 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 919 |
+
else: # Both divisible
|
| 920 |
+
return tl.load(ptr)
|
| 921 |
+
''', device_str='cuda')
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
async_compile.wait(globals())
|
| 925 |
+
del async_compile
|
| 926 |
+
|
| 927 |
+
class Runner:
|
| 928 |
+
def __init__(self, partitions):
|
| 929 |
+
self.partitions = partitions
|
| 930 |
+
|
| 931 |
+
def recursively_apply_fns(self, fns):
|
| 932 |
+
new_callables = []
|
| 933 |
+
for fn, c in zip(fns, self.partitions):
|
| 934 |
+
new_callables.append(fn(c))
|
| 935 |
+
self.partitions = new_callables
|
| 936 |
+
|
| 937 |
+
def call(self, args):
|
| 938 |
+
primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args
|
| 939 |
+
args.clear()
|
| 940 |
+
s37 = primals_8
|
| 941 |
+
s0 = primals_9
|
| 942 |
+
assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
|
| 943 |
+
assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
|
| 944 |
+
assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
|
| 945 |
+
assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 946 |
+
assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1))
|
| 947 |
+
assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1))
|
| 948 |
+
assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 949 |
+
assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1))
|
| 950 |
+
assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 951 |
+
assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1))
|
| 952 |
+
assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 953 |
+
assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
|
| 954 |
+
assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
|
| 955 |
+
assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
|
| 956 |
+
assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1))
|
| 957 |
+
with torch.cuda._DeviceGuard(7):
|
| 958 |
+
torch.cuda.set_device(7)
|
| 959 |
+
buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
|
| 960 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 961 |
+
triton_per_fused_mul_0_xnumel = 32*s37
|
| 962 |
+
stream7 = get_raw_stream(7)
|
| 963 |
+
triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream7)
|
| 964 |
+
del getitem
|
| 965 |
+
del tangents_2
|
| 966 |
+
buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
|
| 967 |
+
buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
|
| 968 |
+
buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16)
|
| 969 |
+
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul]
|
| 970 |
+
stream7 = get_raw_stream(7)
|
| 971 |
+
triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream7)
|
| 972 |
+
del buf1
|
| 973 |
+
del getitem_1
|
| 974 |
+
del primals_10
|
| 975 |
+
del primals_11
|
| 976 |
+
del primals_12
|
| 977 |
+
del primals_13
|
| 978 |
+
del primals_14
|
| 979 |
+
del primals_15
|
| 980 |
+
del primals_16
|
| 981 |
+
del primals_2
|
| 982 |
+
del primals_4
|
| 983 |
+
del primals_6
|
| 984 |
+
del primals_7
|
| 985 |
+
del tangents_1
|
| 986 |
+
return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, )
|
| 987 |
+
|
| 988 |
+
runner = Runner(partitions=[])
|
| 989 |
+
call = runner.call
|
| 990 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 994 |
+
from torch._dynamo.testing import rand_strided
|
| 995 |
+
from torch._inductor.utils import print_performance
|
| 996 |
+
primals_8 = 128
|
| 997 |
+
primals_9 = 128
|
| 998 |
+
primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 999 |
+
primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1000 |
+
primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1001 |
+
primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1002 |
+
primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1003 |
+
primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1004 |
+
primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1005 |
+
primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1006 |
+
primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1007 |
+
primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1008 |
+
primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 1009 |
+
getitem = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1010 |
+
getitem_1 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:7', dtype=torch.float32)
|
| 1011 |
+
tangents_1 = rand_strided((1, 32, 128, 128), (524288, 16384, 128, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 1012 |
+
tangents_2 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:7', dtype=torch.float32)
|
| 1013 |
+
fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2])
|
| 1014 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
if __name__ == "__main__":
|
| 1018 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 1019 |
+
compiled_module_main('None', benchmark_compiled_module)
|
progress/SpecForge/cache/compiled_kernels/5z/c5zh2j5k5rlsr5zd4tfbvoplpwmbtizbrldjq2hw4nndmjztlcuy.py
ADDED
|
@@ -0,0 +1,879 @@
| 1 |
+
# AOT ID: ['3_inference']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wp/cwpsdlduqm7qdaycs5k762qtpc4rkcap2vuyn5ripw5tdmb2n2ce.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# flex_attention => flex_attention
|
| 41 |
+
# Graph fragment:
|
| 42 |
+
# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1]
|
| 43 |
+
# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1]
|
| 44 |
+
# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1]
|
| 45 |
+
# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
|
| 46 |
+
# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
|
| 47 |
+
# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg9_1]
|
| 48 |
+
# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg6_1]
|
| 49 |
+
# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg10_1]
|
| 50 |
+
# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg11_1]
|
| 51 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 52 |
+
# return %buf2
|
| 53 |
+
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
|
| 54 |
+
import triton
|
| 55 |
+
import triton.language as tl
|
| 56 |
+
|
| 57 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 58 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 59 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 60 |
+
|
| 61 |
+
@triton_heuristics.template(
|
| 62 |
+
|
| 63 |
+
num_stages=3,
|
| 64 |
+
num_warps=2,
|
| 65 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 66 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
|
| 67 |
+
|
| 68 |
+
)
|
| 69 |
+
@triton.jit
|
| 70 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
|
| 71 |
+
PRESCALE_QK : tl.constexpr = False
|
| 72 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 73 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 74 |
+
WRITE_DQ : tl.constexpr = True
|
| 75 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 76 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 77 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 78 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 79 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 80 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 81 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 82 |
+
SPLIT_KV : tl.constexpr = 32
|
| 83 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 84 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 85 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 86 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 87 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 88 |
+
BLOCK_M : tl.constexpr = 256
|
| 89 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 90 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 91 |
+
BLOCK_N : tl.constexpr = 64
|
| 92 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 93 |
+
USE_TMA : tl.constexpr = False
|
| 94 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 95 |
+
Q = arg_Q
|
| 96 |
+
K = arg_K
|
| 97 |
+
V = arg_V
|
| 98 |
+
M = arg_M
|
| 99 |
+
L = arg_L
|
| 100 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 101 |
+
KV_IDX = arg_KV_IDX
|
| 102 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 103 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 104 |
+
|
| 105 |
+
# Sub notation for this kernel:
|
| 106 |
+
# Q: Query, K: Key, V: Value
|
| 107 |
+
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
| 108 |
+
# M: Number of queries, N: Number of keys/values
|
| 109 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 110 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 111 |
+
# BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
|
| 112 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
|
| 113 |
+
# (Modifiable) Config options:
|
| 114 |
+
# SPLIT_KV: number of blocks K & V are split into
|
| 115 |
+
# TILE_KV: length of each local KV split
|
| 116 |
+
# BLOCK_M: block size that Q is padded along seqlen dim.
|
| 117 |
+
# BLOCK_N: block size of K & V along N dimension.
|
| 118 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 119 |
+
#
|
| 120 |
+
# change of base out of the loop
|
| 121 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 122 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 123 |
+
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
| 124 |
+
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
| 125 |
+
|
| 126 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
|
| 127 |
+
#
|
| 128 |
+
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
| 129 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 130 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 131 |
+
#
|
| 132 |
+
#
|
| 133 |
+
# Output: ACC output accumulated across local KV split.
|
| 134 |
+
|
| 135 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 136 |
+
|
| 137 |
+
# Define Q Strides
|
| 138 |
+
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
|
| 139 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 140 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 141 |
+
stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
|
| 142 |
+
stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
Z = 1
|
| 146 |
+
ZKV = 1
|
| 147 |
+
HKV = 8
|
| 148 |
+
G: tl.constexpr = GQA_SHARED_HEADS
|
| 149 |
+
HQ = HKV * G
|
| 150 |
+
Q_LEN = ks0
|
| 151 |
+
KV_LEN = ks1
|
| 152 |
+
|
| 153 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 154 |
+
|
| 155 |
+
# Make sure each split is a multiple of BLOCK_N
|
| 156 |
+
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
| 157 |
+
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
| 158 |
+
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
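# Editorial worked example (not generated code): with KV_LEN=1000, SPLIT_KV=32, BLOCK_N=64:
#   TILE_KV_OG       = cdiv(1000, 32)  = 32
#   TILE_KV          = cdiv(32, 64)*64 = 64   (each split rounded up to a whole BLOCK_N)
#   TILE_KV_MULTIPLE = 64 // 64        = 1    (one BLOCK_N iteration per KV split)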
|
| 159 |
+
|
| 160 |
+
off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
|
| 161 |
+
off_zkv = off_z % ZKV
|
| 162 |
+
off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
|
| 163 |
+
off_t = tl.program_id(1).to(INDEX_DTYPE)
|
| 164 |
+
|
| 165 |
+
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
| 166 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 167 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 168 |
+
|
| 169 |
+
K = K + k_offset
|
| 170 |
+
V = V + v_offset
|
| 171 |
+
|
| 172 |
+
SPARSE_Z = 1
|
| 173 |
+
SPARSE_HQ = 1
|
| 174 |
+
|
| 175 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 176 |
+
sparse_idx_h = off_hkv % SPARSE_HQ
|
| 177 |
+
|
| 178 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 179 |
+
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
| 180 |
+
|
| 181 |
+
# initialize pointer to m and l
|
| 182 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 183 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 184 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 185 |
+
|
| 186 |
+
# initialize offsets
|
| 187 |
+
tl.device_assert(BLOCK_M % G == 0)
|
| 188 |
+
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
| 189 |
+
off_g = tl.arange(0, G) # [G]
|
| 190 |
+
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 191 |
+
offs_hq = offs_g + off_hkv * G
|
| 192 |
+
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
| 193 |
+
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
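# Editorial worked example (not generated code): with G=4 and BLOCK_M_PER_HQ=2,
#   offs_g = [0, 0, 1, 1, 2, 2, 3, 3]   (query head within the shared KV group)
#   offs_m = [0, 1, 0, 1, 0, 1, 0, 1]   (row within each query head)
# so the BLOCK_M rows pack all G query heads sharing this KV head, BLOCK_M_PER_HQ rows each.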
|
| 194 |
+
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 195 |
+
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 196 |
+
|
| 197 |
+
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
| 198 |
+
stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
|
| 199 |
+
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
| 200 |
+
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
|
| 201 |
+
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
| 202 |
+
|
| 203 |
+
# Calculate the KV blocks that belong to this CTA.
|
| 204 |
+
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
| 205 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
| 206 |
+
|
| 207 |
+
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
| 208 |
+
|
| 209 |
+
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 210 |
+
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
|
| 211 |
+
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 212 |
+
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
|
| 213 |
+
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
|
| 214 |
+
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
|
| 215 |
+
else:
|
| 216 |
+
q = tl.load(Q + q_offset + q_range)
|
| 217 |
+
|
| 218 |
+
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 222 |
+
# find first kv block we are loading and the number of blocks we are loading
|
| 223 |
+
# Offset the kv_indices tensor by the correct batch and head
|
| 224 |
+
kv_indices = KV_IDX + sparse_idx_hz_offset
|
| 225 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
|
| 226 |
+
MAX_KV_IDX = 1
|
| 227 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 228 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 229 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 230 |
+
# first kv block we're loading
|
| 231 |
+
|
| 232 |
+
# last valid block according to sparse mask
|
| 233 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 234 |
+
|
| 235 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 236 |
+
|
| 237 |
+
desc_k = None
|
| 238 |
+
desc_v = None
|
| 239 |
+
|
| 240 |
+
acc, l_i, m_i = forward_inner(
|
| 241 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 242 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 243 |
+
# accumulated values
|
| 244 |
+
acc, l_i, m_i,
|
| 245 |
+
#offsets
|
| 246 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 247 |
+
off_n,
|
| 248 |
+
#block sparse data
|
| 249 |
+
kv_indices, kv_num_blocks,
|
| 250 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 251 |
+
MATMUL_PRECISION,
|
| 252 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 253 |
+
IS_FULL_BLOCKS=False,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 258 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 259 |
+
# apply mask_mod to them - only score_mod
|
| 260 |
+
if HAS_FULL_BLOCKS:
|
| 261 |
+
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
|
| 262 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
|
| 263 |
+
# Assign full block in a reverse order for off_t. Prioritize the last CTA.
|
| 264 |
+
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
|
| 265 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE
|
| 266 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 267 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 268 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 269 |
+
|
| 270 |
+
# last valid block according to sparse mask
|
| 271 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 272 |
+
|
| 273 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 274 |
+
|
| 275 |
+
acc, l_i, m_i = forward_inner(
|
| 276 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 277 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 278 |
+
# accumulated values
|
| 279 |
+
acc, l_i, m_i,
|
| 280 |
+
#offsets
|
| 281 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 282 |
+
off_n,
|
| 283 |
+
#block sparse data
|
| 284 |
+
kv_indices, kv_num_blocks,
|
| 285 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 286 |
+
MATMUL_PRECISION,
|
| 287 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 288 |
+
IS_FULL_BLOCKS=True,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
m_offset = off_t * stride_mt + off_z * stride_mz
|
| 292 |
+
l_offset = off_t * stride_lt + off_z * stride_lz
|
| 293 |
+
|
| 294 |
+
M_block_ptr = tl.make_block_ptr(
|
| 295 |
+
base=M + m_offset,
|
| 296 |
+
shape=(G, Q_LEN), # (G, M)
|
| 297 |
+
strides=(stride_mh, stride_mm),
|
| 298 |
+
offsets=(off_hkv*G, 0),
|
| 299 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 300 |
+
order=(1, 0)
|
| 301 |
+
)
|
| 302 |
+
L_block_ptr = tl.make_block_ptr(
|
| 303 |
+
base=L + l_offset,
|
| 304 |
+
shape=(G, Q_LEN), # (G, M)
|
| 305 |
+
strides=(stride_lh, stride_lm),
|
| 306 |
+
offsets=(off_hkv*G, 0),
|
| 307 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 308 |
+
order=(1, 0)
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
| 312 |
+
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
| 313 |
+
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
| 314 |
+
if SAFE_M_BOUNDARY:
|
| 315 |
+
tl.store(M_block_ptr, m_i)
|
| 316 |
+
tl.store(L_block_ptr, l_i)
|
| 317 |
+
else:
|
| 318 |
+
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
| 319 |
+
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
| 320 |
+
|
| 321 |
+
# -- store output
|
| 322 |
+
idx_z = off_z
|
| 323 |
+
idx_t = off_t
|
| 324 |
+
idx_hq = off_hkv*G + off_g[:, None, None]
|
| 325 |
+
idx_m = off_m[None, :, None]
|
| 326 |
+
idx_d = offs_vd[None, None, :]
|
| 327 |
+
|
| 328 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 329 |
+
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
|
| 330 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
|
| 331 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Utility triton funcs
|
| 335 |
+
@triton.jit
|
| 336 |
+
def get_offset_for_next_block(
|
| 337 |
+
loop_iter, col_indices, total_blocks,
|
| 338 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 339 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 340 |
+
):
|
| 341 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 342 |
+
return BLOCK
|
| 343 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 344 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 345 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 346 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 347 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 348 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 349 |
+
return offset
|
| 350 |
+
|
| 351 |
+
@triton.jit
|
| 352 |
+
def get_bounded_indices(indices, max_len=None):
|
| 353 |
+
return indices % max_len if max_len is not None else indices
|
| 354 |
+
|
| 355 |
+
@triton.jit
|
| 356 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 357 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 358 |
+
return tl.load(block_ptr)
|
| 359 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 360 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 361 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 362 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 363 |
+
else:
|
| 364 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 365 |
+
|
| 366 |
+
@triton.jit
|
| 367 |
+
def load_checked_2d(
|
| 368 |
+
ptr,
|
| 369 |
+
offs_m,
|
| 370 |
+
offs_n,
|
| 371 |
+
stride_m,
|
| 372 |
+
stride_n,
|
| 373 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 374 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 375 |
+
M_LEN: tl.constexpr,
|
| 376 |
+
N_LEN: tl.constexpr,
|
| 377 |
+
):
|
| 378 |
+
# Calculate final pointer if strides are provided
|
| 379 |
+
if stride_m is not None and stride_n is not None:
|
| 380 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 381 |
+
|
| 382 |
+
# Handle all masking cases
|
| 383 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 384 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 385 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 386 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 387 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 388 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 389 |
+
else: # Both divisible
|
| 390 |
+
return tl.load(ptr)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# Common Imports
|
| 394 |
+
@triton.jit
|
| 395 |
+
def forward_block_mn(
|
| 396 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 397 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 398 |
+
# accumulated values
|
| 399 |
+
acc, l_i, m_i,
|
| 400 |
+
# Offsets
|
| 401 |
+
off_z, off_h, offs_m, offs_n,
|
| 402 |
+
# Offsets needed for TMA loads
|
| 403 |
+
kv_start,
|
| 404 |
+
kv_offset,
|
| 405 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 406 |
+
# Strides for K and V
|
| 407 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 408 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 409 |
+
|
| 410 |
+
):
|
| 411 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 412 |
+
PRESCALE_QK : tl.constexpr = False
|
| 413 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 414 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 415 |
+
WRITE_DQ : tl.constexpr = True
|
| 416 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 417 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 418 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 419 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 420 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 421 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 422 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 423 |
+
SPLIT_KV : tl.constexpr = 32
|
| 424 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 425 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 426 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 427 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 428 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 429 |
+
BLOCK_M : tl.constexpr = 256
|
| 430 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 431 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 432 |
+
BLOCK_N : tl.constexpr = 64
|
| 433 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 434 |
+
USE_TMA : tl.constexpr = False
|
| 435 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# -- load k --
|
| 439 |
+
# NB reversed order since K is transposed
|
| 440 |
+
kv_base_offset = kv_start + kv_offset
|
| 441 |
+
|
| 442 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 443 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 444 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 445 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 446 |
+
|
| 447 |
+
k = tl.trans(k)
|
| 448 |
+
# -- compute qk ---
|
| 449 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 450 |
+
if not PRESCALE_QK:
|
| 451 |
+
qk *= SM_SCALE
|
| 452 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 453 |
+
# If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 454 |
+
# which is larger than the actual number of elements. To avoid accessing memory out of bounds,
|
| 455 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 456 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 457 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 458 |
+
|
| 459 |
+
tmp0 = (qk)
|
| 460 |
+
post_mod_scores = tmp0
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 464 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 465 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 466 |
+
|
| 467 |
+
if not IS_FULL_BLOCKS:
|
| 468 |
+
tmp1 = (m)
|
| 469 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 470 |
+
tmp3 = tmp1 < tmp2
|
| 471 |
+
tmp4 = (n)
|
| 472 |
+
tmp5 = tmp4 <= tmp1
|
| 473 |
+
tmp6 = tmp3 & tmp5
|
| 474 |
+
tmp7 = tmp1 >= tmp2
|
| 475 |
+
tmp8 = tmp4 < tmp2
|
| 476 |
+
tmp9 = tmp7 & tmp8
|
| 477 |
+
tmp10 = tmp8 == 0
|
| 478 |
+
tmp11 = tmp7 & tmp10
|
| 479 |
+
tmp12 = tmp1 - tmp2
|
| 480 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 481 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 482 |
+
tmp15 = tmp4 - tmp2
|
| 483 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 484 |
+
tmp17 = tmp14 == tmp16
|
| 485 |
+
tmp18 = tmp11 & tmp17
|
| 486 |
+
tmp19 = tmp9 | tmp18
|
| 487 |
+
tmp20 = tmp6 | tmp19
|
| 488 |
+
mask_mod_output = tmp20
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 492 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 493 |
+
# apply mask for partially unmasked blocks
|
| 494 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 495 |
+
|
| 496 |
+
if not PRESCALE_QK:
|
| 497 |
+
post_mod_scores *= RCP_LN2
|
| 498 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 499 |
+
|
| 500 |
+
# -- compute scaling constant ---
|
| 501 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 502 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 503 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 504 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 505 |
+
else:
|
| 506 |
+
m_ij_masked = m_ij
|
| 507 |
+
|
| 508 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 509 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 510 |
+
|
| 511 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 512 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 513 |
+
# m_ij
|
| 514 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 515 |
+
# -- scale and update acc --
|
| 516 |
+
acc = acc * alpha[:, None]
|
| 517 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 518 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 519 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 520 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 521 |
+
|
| 522 |
+
# -- update m_i
|
| 523 |
+
m_i = m_ij
|
| 524 |
+
|
| 525 |
+
return acc, l_i, m_i
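# --- Editorial reference sketch (assumptions: plain PyTorch; exp plays the role of the
# kernel's exp2/RCP_LN2 change of base; no masking; the helper name is hypothetical) ---
import torch

def online_softmax_step(acc, l_i, m_i, scores, v):
    """Eager-mode view of the rescaling done above: update the running row max,
    denominator and weighted sum when a new block of scores arrives."""
    m_ij = torch.maximum(m_i, scores.max(dim=-1).values)
    alpha = torch.exp(m_i - m_ij)             # rescale factor for the previous running state
    p = torch.exp(scores - m_ij[:, None])     # block probabilities in the new reference frame
    l_i = l_i * alpha + p.sum(dim=-1)
    acc = acc * alpha[:, None] + p @ v
    return acc, l_i, m_ij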
|
| 526 |
+
|
| 527 |
+
@triton.jit
|
| 528 |
+
def forward_inner(
|
| 529 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 530 |
+
q, K, V,
|
| 531 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 532 |
+
# accumulated values
|
| 533 |
+
acc, l_i, m_i,
|
| 534 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 535 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 536 |
+
off_z, off_h, offs_m, offs_n,
|
| 537 |
+
# Offsets needed for TMA loads
|
| 538 |
+
kv_start,
|
| 539 |
+
# blocksparse data
|
| 540 |
+
kv_indices, kv_num_blocks,
|
| 541 |
+
# start kv and end kv block
|
| 542 |
+
block_n_start, block_n_end,
|
| 543 |
+
MATMUL_PRECISION,
|
| 544 |
+
# Strides for K and V
|
| 545 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 546 |
+
IS_FULL_BLOCKS,
|
| 547 |
+
):
|
| 548 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 549 |
+
PRESCALE_QK : tl.constexpr = False
|
| 550 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 551 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 552 |
+
WRITE_DQ : tl.constexpr = True
|
| 553 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 554 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 555 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 556 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 557 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 558 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
SPLIT_KV : tl.constexpr = 32
|
| 561 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 562 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 563 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 565 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 566 |
+
BLOCK_M : tl.constexpr = 256
|
| 567 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 568 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 569 |
+
BLOCK_N : tl.constexpr = 64
|
| 570 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 571 |
+
USE_TMA : tl.constexpr = False
|
| 572 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
|
| 578 |
+
if PRESCALE_QK:
|
| 579 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 580 |
+
|
| 581 |
+
kv_offset = 0
|
| 582 |
+
|
| 583 |
+
# loop over k, v and update accumulator until block_n_end
|
| 584 |
+
for start_n in range(block_n_start, block_n_end):
|
| 585 |
+
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 586 |
+
if IS_DIVISIBLE:
|
| 587 |
+
acc, l_i, m_i = forward_block_mn(
|
| 588 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 589 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 590 |
+
# accumulated values
|
| 591 |
+
acc, l_i, m_i,
|
| 592 |
+
# Offsets
|
| 593 |
+
off_z, off_h, offs_m, offs_n,
|
| 594 |
+
# Offsets needed for TMA loads
|
| 595 |
+
kv_start,
|
| 596 |
+
kv_offset,
|
| 597 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 598 |
+
# Strides for K and V
|
| 599 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
else:
|
| 603 |
+
# Benchmarks show that even when we apply mod & mask to every block for non-divisible seqlens,
|
| 604 |
+
# it's on par with or slightly faster than applying them only to the last block in fwd.
|
| 605 |
+
# However, we choose a different strategy for bwd, where we only apply mod & mask
|
| 606 |
+
# to the last block, because that is much faster.
|
| 607 |
+
acc, l_i, m_i = forward_block_mn(
|
| 608 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 609 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 610 |
+
# accumulated values
|
| 611 |
+
acc, l_i, m_i,
|
| 612 |
+
# Offsets
|
| 613 |
+
off_z, off_h, offs_m, offs_n,
|
| 614 |
+
# Offsets needed for TMA loads
|
| 615 |
+
kv_start,
|
| 616 |
+
kv_offset,
|
| 617 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 618 |
+
# Strides for K and V
|
| 619 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 620 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
offset = get_offset_for_next_block(
|
| 626 |
+
start_n, kv_indices, kv_num_blocks,
|
| 627 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
offs_n = offs_n + offset
|
| 631 |
+
kv_offset += offset
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
return acc, l_i, m_i
|
| 635 |
+
''', device_str='cuda')
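# NOTE (illustrative sketch, not part of the generated output_code): the
# forward_block_mn / forward_inner helpers above implement an online softmax in
# base-2 space (RCP_LN2 converts e^x into 2^(x * RCP_LN2)). A minimal NumPy
# rendering of that per-block accumulator update, with hypothetical names,
# might look like this:
def _online_softmax_update_sketch(acc, l_i, m_i, scores, v, rcp_ln2=1.44269504):
    import numpy as np
    # scores: [M, N] = (q @ k.T) * SM_SCALE for the current KV block;
    # acc: [M, Dv] running weighted sum, l_i: [M] running denominator,
    # m_i: [M] running row max -- all tracked in base-2 space as in the kernel.
    s = scores * rcp_ln2
    m_ij = np.maximum(m_i, s.max(axis=1))       # new running max
    alpha = np.exp2(m_i - m_ij)                 # rescale factor for the old state
    p = np.exp2(s - m_ij[:, None])              # unnormalized block probabilities
    l_i = l_i * alpha + p.sum(axis=1)
    acc = acc * alpha[:, None] + p @ v
    return acc, l_i, m_ij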
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/3p/c3pdrhexk4rwol7f5l5vh7n543dj6piq6gw5k66g2p4vlyhopnop.py
|
| 639 |
+
# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
|
| 640 |
+
# Source node to ATen node mapping:
|
| 641 |
+
# flex_attention => flex_attention
|
| 642 |
+
# lse_scaled => mul_9
|
| 643 |
+
# Graph fragment:
|
| 644 |
+
# %buf3 : Tensor = PlaceHolder[target=buf3]
|
| 645 |
+
# %buf4 : Tensor = PlaceHolder[target=buf4]
|
| 646 |
+
# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
|
| 647 |
+
# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
|
| 648 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 649 |
+
# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
|
| 650 |
+
# return %buf5,%buf7,%mul_9
|
| 651 |
+
triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', '''
|
| 652 |
+
import triton
|
| 653 |
+
import triton.language as tl
|
| 654 |
+
|
| 655 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 656 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 657 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 658 |
+
triton_helpers.set_driver_to_gpu()
|
| 659 |
+
|
| 660 |
+
@triton_heuristics.persistent_reduction(
|
| 661 |
+
size_hints={'x': 2048, 'r0_': 32},
|
| 662 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 663 |
+
filename=__file__,
|
| 664 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 665 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 666 |
+
)
|
| 667 |
+
@triton.jit
|
| 668 |
+
def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 669 |
+
r0_numel = 32
|
| 670 |
+
R0_BLOCK: tl.constexpr = 32
|
| 671 |
+
rnumel = r0_numel
|
| 672 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 673 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 674 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 675 |
+
xmask = xindex < xnumel
|
| 676 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 677 |
+
r0_offset = 0
|
| 678 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 679 |
+
roffset = r0_offset
|
| 680 |
+
rindex = r0_index
|
| 681 |
+
r0_1 = r0_index
|
| 682 |
+
x0 = xindex
|
| 683 |
+
x2 = (xindex % ks0)
|
| 684 |
+
x3 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 685 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 686 |
+
tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 687 |
+
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
|
| 688 |
+
tmp3 = tl.where(xmask, tmp1, float("-inf"))
|
| 689 |
+
tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
|
| 690 |
+
tmp6 = float("-inf")
|
| 691 |
+
tmp7 = tmp4 == tmp6
|
| 692 |
+
tmp8 = tmp0 - tmp4
|
| 693 |
+
tmp9 = 0.0
|
| 694 |
+
tmp10 = tl.where(tmp7, tmp9, tmp8)
|
| 695 |
+
tmp11 = libdevice.exp2(tmp10)
|
| 696 |
+
tmp12 = tmp5 * tmp11
|
| 697 |
+
tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
|
| 698 |
+
tmp15 = tl.where(xmask, tmp13, 0)
|
| 699 |
+
tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
|
| 700 |
+
tmp17 = 1.0
|
| 701 |
+
tmp18 = tl.where(tmp7, tmp17, tmp16)
|
| 702 |
+
tmp19 = libdevice.log2(tmp18)
|
| 703 |
+
tmp20 = tmp19 + tmp4
|
| 704 |
+
tmp21 = 0.6931471805599453
|
| 705 |
+
tmp22 = tmp20 * tmp21
|
| 706 |
+
tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
|
| 707 |
+
tl.store(out_ptr0 + (x0), tmp4, xmask)
|
| 708 |
+
tl.store(out_ptr1 + (x0), tmp16, xmask)
|
| 709 |
+
''', device_str='cuda')
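# NOTE (illustrative sketch, not part of the generated output_code):
# triton_per_fused_mul_1 reduces the per-split (m, l) pairs produced by the
# split-KV forward kernel and emits the logsumexp in natural log by scaling
# the base-2 result with ln(2) = 0.6931471805599453. Roughly, ignoring fully
# masked rows:
def _combine_split_lse_sketch(m_split, l_split, ln2=0.6931471805599453):
    import numpy as np
    # m_split, l_split: [num_splits, rows] per-split row max and softmax
    # denominator, both kept in base-2 space by the forward kernel.
    m = m_split.max(axis=0)                              # combined row max
    l = (l_split * np.exp2(m_split - m)).sum(axis=0)     # combined denominator
    return m, l, (np.log2(l) + m) * ln2                  # LSE in natural log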
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/so/csovo4ylg3oadsjdsgqzxojai65n6xwa5n7magfct6j23x35wkeq.py
|
| 713 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 714 |
+
# Source node to ATen node mapping:
|
| 715 |
+
# flex_attention => flex_attention, getitem
|
| 716 |
+
# Graph fragment:
|
| 717 |
+
# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf2]
|
| 718 |
+
# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
|
| 719 |
+
# %buf3 : Tensor = PlaceHolder[target=buf3]
|
| 720 |
+
# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf8]
|
| 721 |
+
# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
|
| 722 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 723 |
+
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {})
|
| 724 |
+
# return %buf8,%getitem
|
| 725 |
+
triton_per_fused_2 = async_compile.triton('triton_per_fused_2', '''
|
| 726 |
+
import triton
|
| 727 |
+
import triton.language as tl
|
| 728 |
+
|
| 729 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 730 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 731 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 732 |
+
triton_helpers.set_driver_to_gpu()
|
| 733 |
+
|
| 734 |
+
@triton_heuristics.persistent_reduction(
|
| 735 |
+
size_hints={'x': 262144, 'r0_': 32},
|
| 736 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 737 |
+
filename=__file__,
|
| 738 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
|
| 739 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 740 |
+
)
|
| 741 |
+
@triton.jit
|
| 742 |
+
def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 743 |
+
r0_numel = 32
|
| 744 |
+
R0_BLOCK: tl.constexpr = 32
|
| 745 |
+
rnumel = r0_numel
|
| 746 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 747 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 748 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 749 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 750 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 751 |
+
r0_offset = 0
|
| 752 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 753 |
+
roffset = r0_offset
|
| 754 |
+
rindex = r0_index
|
| 755 |
+
r0_2 = r0_index
|
| 756 |
+
x5 = xindex
|
| 757 |
+
x1 = xindex // 128
|
| 758 |
+
x0 = (xindex % 128)
|
| 759 |
+
x3 = ((xindex // 128) % ks0)
|
| 760 |
+
x4 = xindex // ks1
|
| 761 |
+
tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
|
| 762 |
+
tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
|
| 763 |
+
tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
|
| 764 |
+
tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
|
| 765 |
+
tmp2 = float("-inf")
|
| 766 |
+
tmp3 = tmp1 == tmp2
|
| 767 |
+
tmp5 = tmp4 - tmp1
|
| 768 |
+
tmp6 = 0.0
|
| 769 |
+
tmp7 = tl.where(tmp3, tmp6, tmp5)
|
| 770 |
+
tmp8 = libdevice.exp2(tmp7)
|
| 771 |
+
tmp9 = tmp0 * tmp8
|
| 772 |
+
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
|
| 773 |
+
tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
|
| 774 |
+
tmp14 = 1.0
|
| 775 |
+
tmp15 = tl.where(tmp3, tmp14, tmp13)
|
| 776 |
+
tmp16 = (tmp12 / tmp15)
|
| 777 |
+
tmp17 = tmp16.to(tl.float32)
|
| 778 |
+
tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
|
| 779 |
+
''', device_str='cuda')
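# NOTE (illustrative sketch, not part of the generated output_code):
# triton_per_fused_2 folds the 32 per-split accumulators into the final
# attention output: each split is rescaled by exp2(m_split - m) and the sum is
# divided by the combined denominator from the previous reduction.
def _combine_split_outputs_sketch(acc_split, m_split, m, l):
    import numpy as np
    # acc_split: [num_splits, rows, head_dim]; m_split: [num_splits, rows];
    # m, l: combined row max and denominator (see the sketch above).
    out = (acc_split * np.exp2(m_split - m)[:, :, None]).sum(axis=0)
    return out / l[:, None]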
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
async_compile.wait(globals())
|
| 783 |
+
del async_compile
|
| 784 |
+
|
| 785 |
+
class Runner:
|
| 786 |
+
def __init__(self, partitions):
|
| 787 |
+
self.partitions = partitions
|
| 788 |
+
|
| 789 |
+
def recursively_apply_fns(self, fns):
|
| 790 |
+
new_callables = []
|
| 791 |
+
for fn, c in zip(fns, self.partitions):
|
| 792 |
+
new_callables.append(fn(c))
|
| 793 |
+
self.partitions = new_callables
|
| 794 |
+
|
| 795 |
+
def call(self, args):
|
| 796 |
+
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
|
| 797 |
+
args.clear()
|
| 798 |
+
s50 = arg0_1
|
| 799 |
+
s0 = arg2_1
|
| 800 |
+
s43 = arg4_1
|
| 801 |
+
s37 = arg7_1
|
| 802 |
+
s71 = arg8_1
|
| 803 |
+
assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
|
| 804 |
+
assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
|
| 805 |
+
assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
|
| 806 |
+
assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 807 |
+
assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
|
| 808 |
+
assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
|
| 809 |
+
assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 810 |
+
assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
|
| 811 |
+
assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 812 |
+
assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
|
| 813 |
+
assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
|
| 814 |
+
with torch.cuda._DeviceGuard(7):
|
| 815 |
+
torch.cuda.set_device(7)
|
| 816 |
+
buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
|
| 817 |
+
buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
|
| 818 |
+
buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32)
|
| 819 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 820 |
+
stream7 = get_raw_stream(7)
|
| 821 |
+
triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream7)
|
| 822 |
+
del arg10_1
|
| 823 |
+
del arg11_1
|
| 824 |
+
del arg1_1
|
| 825 |
+
del arg3_1
|
| 826 |
+
del arg5_1
|
| 827 |
+
del arg6_1
|
| 828 |
+
del arg9_1
|
| 829 |
+
buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32)
|
| 830 |
+
buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
|
| 831 |
+
buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
|
| 832 |
+
# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
|
| 833 |
+
triton_per_fused_mul_1_xnumel = 32*s37
|
| 834 |
+
stream7 = get_raw_stream(7)
|
| 835 |
+
triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream7)
|
| 836 |
+
del buf1
|
| 837 |
+
ps0 = 128*s37
|
| 838 |
+
buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
|
| 839 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 840 |
+
triton_per_fused_2_xnumel = 4096*s37
|
| 841 |
+
stream7 = get_raw_stream(7)
|
| 842 |
+
triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream7)
|
| 843 |
+
del buf0
|
| 844 |
+
del buf2
|
| 845 |
+
del buf5
|
| 846 |
+
del buf7
|
| 847 |
+
return (buf9, buf10, )
|
| 848 |
+
|
| 849 |
+
runner = Runner(partitions=[])
|
| 850 |
+
call = runner.call
|
| 851 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 855 |
+
from torch._dynamo.testing import rand_strided
|
| 856 |
+
from torch._inductor.utils import print_performance
|
| 857 |
+
arg0_1 = 48
|
| 858 |
+
arg1_1 = rand_strided((1, 32, 48, 128), (196608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 859 |
+
arg2_1 = 48
|
| 860 |
+
arg3_1 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 861 |
+
arg4_1 = 48
|
| 862 |
+
arg5_1 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
|
| 863 |
+
arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 864 |
+
arg7_1 = 48
|
| 865 |
+
arg8_1 = 48
|
| 866 |
+
arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 867 |
+
arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 868 |
+
arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 869 |
+
arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 870 |
+
arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 871 |
+
arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 872 |
+
arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
|
| 873 |
+
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
|
| 874 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
if __name__ == "__main__":
|
| 878 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 879 |
+
compiled_module_main('None', benchmark_compiled_module)
|
progress/SpecForge/cache/compiled_kernels/66/c66vzxeqjq4tywx6ezsscs3u3rb6yxac26rzmkwrlzc3kmkcnhlf.py
ADDED
|
@@ -0,0 +1,799 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 36 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 37 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 38 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 39 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 41 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 42 |
+
Q = arg_Q
|
| 43 |
+
K = arg_K
|
| 44 |
+
V = arg_V
|
| 45 |
+
LSE = arg_LSE
|
| 46 |
+
DELTA = arg_DELTA
|
| 47 |
+
DO = arg_DO
|
| 48 |
+
DQ = arg_DQ
|
| 49 |
+
DV = arg_DV
|
| 50 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 51 |
+
KV_IDX = arg_KV_IDX
|
| 52 |
+
Q_NUM_BLKS = arg_Q_NUM_BLKS
|
| 53 |
+
Q_IDX = arg_Q_IDX
|
| 54 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 55 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 56 |
+
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
|
| 57 |
+
FULL_Q_IDX = arg_FULL_Q_IDX
|
| 58 |
+
|
| 59 |
+
# Sub notation for this kernel:
|
| 60 |
+
#
|
| 61 |
+
# Q: Query, K: Key, V: Value
|
| 62 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 63 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 64 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 65 |
+
# DK: Derivative of Key, is written to via the store_output call due to some limitations with
|
| 66 |
+
# inductor codegen
|
| 67 |
+
# M: Number of queries, N: Number of keys/values
|
| 68 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 69 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 70 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 71 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 72 |
+
# (Modifiable) Performance tuning options
|
| 73 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 74 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 75 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 76 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 77 |
+
#
|
| 78 |
+
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
| 79 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 80 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 81 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 82 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 83 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 84 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 85 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 86 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 87 |
+
|
| 88 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 89 |
+
# or involve a numerics vs. perf tradeoff
|
| 90 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 91 |
+
# about 20% more numerical error, but slightly faster.
|
| 92 |
+
|
| 93 |
+
# Define strides of inputs
|
| 94 |
+
stride_qz, stride_qh, stride_qm, stride_qd = 4521984, 128, 4096, 1
|
| 95 |
+
stride_kz, stride_kh, stride_kn, stride_kd = 1130496, 128, 1024, 1
|
| 96 |
+
stride_vz, stride_vh, stride_vn, stride_vd = 1130496, 128, 1024, 1
|
| 97 |
+
stride_doz, stride_doh, stride_dom, stride_dod = 4521984, 141312, 128, 1
|
| 98 |
+
|
| 99 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4521984, 128, 4096, 1
|
| 100 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1130496, 128, 1024, 1
|
| 101 |
+
|
| 102 |
+
ZQ = 1
|
| 103 |
+
HQ = 32
|
| 104 |
+
HKV = 8
|
| 105 |
+
Q_LEN = 1104
|
| 106 |
+
ZKV = 1
|
| 107 |
+
KV_LEN = 1104
|
| 108 |
+
|
| 109 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 110 |
+
|
| 111 |
+
pid = tl.program_id(0).to(INDEX_DTYPE)
|
| 112 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 113 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 114 |
+
|
| 115 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
|
| 116 |
+
off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
|
| 117 |
+
off_zkv = off_zq % ZKV # kv batch idx
|
| 118 |
+
|
| 119 |
+
SPARSE_Z = 1
|
| 120 |
+
SPARSE_HQ = 1
|
| 121 |
+
|
| 122 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 123 |
+
|
| 124 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
| 125 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
| 126 |
+
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 127 |
+
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 128 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
| 129 |
+
|
| 130 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 131 |
+
K += k_adj
|
| 132 |
+
V += v_adj
|
| 133 |
+
DV += dv_adj
|
| 134 |
+
|
| 135 |
+
RCP_LN2 = 1.44269504
|
| 136 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 137 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 138 |
+
|
| 139 |
+
if pid >= NUM_KV_BLOCKS:
|
| 140 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 141 |
+
# THIS BLOCK DOES DQ
|
| 142 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 143 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 144 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 145 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 146 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 147 |
+
stride_kv_num_blks_h = 9
|
| 148 |
+
stride_kv_idx_h = 81
|
| 149 |
+
stride_kv_idx_m = 9
|
| 150 |
+
|
| 151 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 152 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 153 |
+
|
| 154 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 155 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 156 |
+
|
| 157 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 158 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
| 159 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
| 160 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
| 161 |
+
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 162 |
+
|
| 163 |
+
Q2 = Q + q_adj2
|
| 164 |
+
DO2 = DO + do_adj2
|
| 165 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 166 |
+
# if Q is broadcasted)
|
| 167 |
+
DQ2 = DQ + dq_adj2
|
| 168 |
+
LSE2 = LSE + off_chz2
|
| 169 |
+
DELTA2 = DELTA + off_chz2
|
| 170 |
+
|
| 171 |
+
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 172 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 173 |
+
|
| 174 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 175 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 176 |
+
|
| 177 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 178 |
+
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 179 |
+
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
| 180 |
+
|
| 181 |
+
if PRESCALE_QK:
|
| 182 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 183 |
+
|
| 184 |
+
if IS_DIVISIBLE:
|
| 185 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 186 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 187 |
+
else:
|
| 188 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 189 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 190 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 191 |
+
lse = lse[:, None]
|
| 192 |
+
|
| 193 |
+
# ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 194 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 195 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 196 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 197 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 198 |
+
|
| 199 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 200 |
+
dq = bwd_dq_inner(
|
| 201 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 202 |
+
K, V,
|
| 203 |
+
dq, q, do, Di, lse,
|
| 204 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 205 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 206 |
+
kv_indices, sparse_kv_num_blocks,
|
| 207 |
+
MATMUL_PRECISION,
|
| 208 |
+
IS_FULL_BLOCKS=False,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if HAS_FULL_BLOCKS:
|
| 212 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 213 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 214 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 215 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 216 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 217 |
+
|
| 218 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 219 |
+
dq = bwd_dq_inner(
|
| 220 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 221 |
+
K, V,
|
| 222 |
+
dq, q, do, Di, lse,
|
| 223 |
+
off_zq, off_hq2, offs_m2, offs_n2,
|
| 224 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 225 |
+
kv_indices, sparse_kv_num_blocks,
|
| 226 |
+
MATMUL_PRECISION,
|
| 227 |
+
IS_FULL_BLOCKS=True,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Write back dQ.
|
| 231 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 232 |
+
dq *= SM_SCALE
|
| 233 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 234 |
+
tl.store(dq_ptrs, dq)
|
| 235 |
+
else:
|
| 236 |
+
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
| 237 |
+
else:
|
| 238 |
+
# THIS BLOCK DOES DK & DV
|
| 239 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 240 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 241 |
+
|
| 242 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 243 |
+
|
| 244 |
+
stride_q_num_blks_h = 9
|
| 245 |
+
stride_q_idx_h = 81
|
| 246 |
+
stride_q_idx_n = 9
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 250 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 251 |
+
|
| 252 |
+
start_n1 = pid * BLOCK_N1
|
| 253 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 254 |
+
|
| 255 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 256 |
+
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 257 |
+
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 258 |
+
|
| 259 |
+
if PRESCALE_QK:
|
| 260 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 261 |
+
|
| 262 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 263 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 264 |
+
|
| 265 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
|
| 266 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
| 267 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
| 268 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
| 269 |
+
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 270 |
+
|
| 271 |
+
Q1 = Q + q_adj1
|
| 272 |
+
DO1 = DO + do_adj1
|
| 273 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 274 |
+
# if Q is broadcasted)
|
| 275 |
+
LSE1 = LSE + off_chz1
|
| 276 |
+
DELTA1 = DELTA + off_chz1
|
| 277 |
+
|
| 278 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 279 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 280 |
+
|
| 281 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 282 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 283 |
+
|
| 284 |
+
# ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 285 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 286 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 287 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 288 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 289 |
+
|
| 290 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 291 |
+
dk, dv = bwd_dkdv_inner(
|
| 292 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 293 |
+
Q1, DO1, DELTA1, LSE1,
|
| 294 |
+
dk, dv, k, v,
|
| 295 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 296 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 297 |
+
q_indices, sparse_q_num_blocks,
|
| 298 |
+
MATMUL_PRECISION,
|
| 299 |
+
IS_FULL_BLOCKS=False,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if HAS_FULL_BLOCKS:
|
| 304 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 305 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 306 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 307 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 308 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 309 |
+
|
| 310 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 311 |
+
dk, dv = bwd_dkdv_inner(
|
| 312 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 313 |
+
Q1, DO1, DELTA1, LSE1,
|
| 314 |
+
dk, dv, k, v,
|
| 315 |
+
off_zq, off_hq1, offs_n1, offs_m1,
|
| 316 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 317 |
+
q_indices, sparse_q_num_blocks,
|
| 318 |
+
MATMUL_PRECISION,
|
| 319 |
+
IS_FULL_BLOCKS=True,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Write back dV and dK.
|
| 323 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 324 |
+
|
| 325 |
+
index_n = offs_n1[:, None]
|
| 326 |
+
index_k = offs_k[None, :]
|
| 327 |
+
index_v = offs_v[None, :]
|
| 328 |
+
|
| 329 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 330 |
+
tl.store(dv_ptrs, dv)
|
| 331 |
+
else:
|
| 332 |
+
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
| 333 |
+
|
| 334 |
+
dk *= SM_SCALE
|
| 335 |
+
|
| 336 |
+
if SAFE_HEAD_DIM:
|
| 337 |
+
mask = index_n < KV_LEN
|
| 338 |
+
else:
|
| 339 |
+
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
| 340 |
+
|
| 341 |
+
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
| 342 |
+
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
| 343 |
+
tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
|
| 344 |
+
xindex = index_k + 128*index_n + 141312*off_hkv + 1130496*off_zq
|
| 345 |
+
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
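# NOTE (illustrative sketch, not part of the generated kernel): the backward
# template above dispatches work by program id: the first NUM_KV_BLOCKS
# programs each own one KV tile and compute dK/dV for it, while the remaining
# programs each own one (query head, Q tile) pair and compute dQ. A plain
# Python rendering of that dispatch, with hypothetical names:
def _describe_bwd_pid_sketch(pid, q_len=1104, kv_len=1104,
                             block_n1=128, block_m2=128):
    import math
    num_kv_blocks = math.ceil(kv_len / block_n1)
    num_q_blocks = math.ceil(q_len / block_m2)
    if pid < num_kv_blocks:
        return ("dK/dV", {"kv_block": pid})
    off_pid = pid - num_kv_blocks
    return ("dQ", {"gqa_head_offset": off_pid // num_q_blocks,
                   "q_block": off_pid % num_q_blocks})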
|
| 346 |
+
|
| 347 |
+
@triton.jit
|
| 348 |
+
def bwd_dq_inner(
|
| 349 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 350 |
+
K, V, # pointers
|
| 351 |
+
dq, q, do, Di, lse,
|
| 352 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 353 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 354 |
+
kv_indices, sparse_kv_num_blocks,
|
| 355 |
+
MATMUL_PRECISION,
|
| 356 |
+
IS_FULL_BLOCKS,
|
| 357 |
+
):
|
| 358 |
+
PRESCALE_QK : tl.constexpr = False
|
| 359 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 360 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 361 |
+
WRITE_DQ : tl.constexpr = True
|
| 362 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 363 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 364 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 365 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 366 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 367 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 368 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 369 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 370 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 371 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 372 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 373 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 374 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 375 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 376 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 377 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 378 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 379 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 380 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 381 |
+
|
| 382 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 383 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 384 |
+
Q_LEN = 1104
|
| 385 |
+
KV_LEN = 1104
|
| 386 |
+
|
| 387 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 388 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 389 |
+
|
| 390 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 391 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 392 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 393 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 394 |
+
|
| 395 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 396 |
+
|
| 397 |
+
for start_n in range(0, hi):
|
| 398 |
+
dq = bwd_dq_block_mn(
|
| 399 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 400 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 401 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 402 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 403 |
+
kv_indices, sparse_kv_num_blocks,
|
| 404 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 405 |
+
IS_FULL_BLOCKS,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Increment pointers.
|
| 409 |
+
offset = get_offset_for_next_block(
|
| 410 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 411 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
kT_ptrs += offset * stride_kn
|
| 415 |
+
vT_ptrs += offset * stride_vn
|
| 416 |
+
|
| 417 |
+
offs_n2 += offset
|
| 418 |
+
|
| 419 |
+
return dq
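# NOTE (illustrative sketch, not part of the generated kernel): bwd_dq_inner
# walks the block-sparse KV blocks through get_offset_for_next_block (defined
# elsewhere in the generated source). Under the assumption that it follows the
# usual flex-attention block-sparse iteration, the pointer advance is roughly:
def _next_block_offset_sketch(loop_iter, block_indices, sparse_block_size,
                              sparse_multiple, block_n):
    # Step by BLOCK_N inside a sparse block; at a sparse-block boundary, jump
    # to the next block listed in block_indices (hypothetical helper name).
    if (loop_iter + 1) % sparse_multiple != 0:
        return block_n
    i = loop_iter // sparse_multiple
    if i + 1 >= len(block_indices):
        return block_n
    jump = (block_indices[i + 1] - block_indices[i]) * sparse_block_size
    return jump - (sparse_multiple - 1) * block_n + block_n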
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
@triton.jit
|
| 423 |
+
def bwd_dq_block_mn(
|
| 424 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 425 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 426 |
+
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
| 427 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 428 |
+
kv_indices, sparse_kv_num_blocks,
|
| 429 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 430 |
+
IS_FULL_BLOCKS,
|
| 431 |
+
):
|
| 432 |
+
PRESCALE_QK : tl.constexpr = False
|
| 433 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 434 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 435 |
+
WRITE_DQ : tl.constexpr = True
|
| 436 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 437 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 438 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 439 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 440 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 441 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 442 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 443 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 444 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 445 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 446 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 447 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 448 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 449 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 450 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 451 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 452 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 453 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 454 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# NB reversed order since K is transposed
|
| 458 |
+
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
| 459 |
+
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
| 460 |
+
if not PRESCALE_QK:
|
| 461 |
+
qk *= SM_SCALE
|
| 462 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 463 |
+
pre_mod_scores = qk
|
| 464 |
+
n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
|
| 465 |
+
# The boundary check is done for the outer loop, but since we're iterating across the N dim here,
|
| 466 |
+
# the M reads can go out of bounds for the PIDs spanning the Q_LEN boundary
|
| 467 |
+
m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
|
| 468 |
+
|
| 469 |
+
tmp0 = (qk)
|
| 470 |
+
post_mod_scores = tmp0
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if not IS_DIVISIBLE:
|
| 476 |
+
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
| 477 |
+
|
| 478 |
+
if not IS_FULL_BLOCKS:
|
| 479 |
+
tmp1 = (m)
|
| 480 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 481 |
+
tmp3 = tmp1 < tmp2
|
| 482 |
+
tmp4 = (n)
|
| 483 |
+
tmp5 = tmp4 <= tmp1
|
| 484 |
+
tmp6 = tmp3 & tmp5
|
| 485 |
+
tmp7 = tmp1 >= tmp2
|
| 486 |
+
tmp8 = tmp4 < tmp2
|
| 487 |
+
tmp9 = tmp7 & tmp8
|
| 488 |
+
tmp10 = tmp8 == 0
|
| 489 |
+
tmp11 = tmp7 & tmp10
|
| 490 |
+
tmp12 = tmp1 - tmp2
|
| 491 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 492 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 493 |
+
tmp15 = tmp4 - tmp2
|
| 494 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 495 |
+
tmp17 = tmp14 == tmp16
|
| 496 |
+
tmp18 = tmp11 & tmp17
|
| 497 |
+
tmp19 = tmp9 | tmp18
|
| 498 |
+
tmp20 = tmp6 | tmp19
|
| 499 |
+
mask_mod_output = tmp20
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# apply mask for partial masked block
|
| 503 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 504 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 505 |
+
if not PRESCALE_QK:
|
| 506 |
+
post_mod_scores *= RCP_LN2
|
| 507 |
+
p = tl.math.exp2(post_mod_scores - lse)
|
| 508 |
+
# Compute dP and dS.
|
| 509 |
+
# NB reversed order since V is transposed
|
| 510 |
+
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
| 511 |
+
|
| 512 |
+
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
| 513 |
+
ds = p * (dp - Di[:, None])
|
| 514 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
| 515 |
+
tmp21 = (ds)
|
| 516 |
+
grad_scores = tmp21
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
if not IS_DIVISIBLE:
|
| 520 |
+
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
| 521 |
+
|
| 522 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
| 523 |
+
if WRITE_DQ:
|
| 524 |
+
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
| 525 |
+
|
| 526 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 527 |
+
ds = grad_scores
|
| 528 |
+
|
| 529 |
+
if not IS_FULL_BLOCKS:
|
| 530 |
+
# (grads) apply mask for partially unmasked block
|
| 531 |
+
ds = tl.where(mask_mod_output, ds, 0.0)
|
| 532 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 533 |
+
ds = ds.to(MATMUL_PRECISION)
|
| 534 |
+
# Compute dQ.
|
| 535 |
+
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
| 536 |
+
|
| 537 |
+
return dq
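# NOTE (illustrative sketch, not part of the generated kernel): the dQ block
# above recomputes the block probabilities from the saved base-2 logsumexp and
# forms dS = P * (dP - Delta). The same math in NumPy, for one tile:
def _dq_block_sketch(q, k, v, do, lse, delta, sm_scale, rcp_ln2=1.44269504):
    import numpy as np
    # q: [M, D], k/v: [N, D]/[N, Dv], do: [M, Dv], lse/delta: [M].
    s = (q @ k.T) * sm_scale * rcp_ln2       # scores in base-2 space
    p = np.exp2(s - lse[:, None])            # recomputed probabilities
    dp = do @ v.T
    ds = p * (dp - delta[:, None])
    return (ds @ k) * sm_scale               # the kernel applies SM_SCALE at write-back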
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
@triton.jit
|
| 541 |
+
def bwd_dkdv_inner(
|
| 542 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 543 |
+
Q, DO, DELTA, LSE, # pointers
|
| 544 |
+
dk, dv, k, v,
|
| 545 |
+
off_z, off_hq, offs_n1, offs_m1,
|
| 546 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 547 |
+
q_indices, sparse_q_num_blocks,
|
| 548 |
+
MATMUL_PRECISION,
|
| 549 |
+
IS_FULL_BLOCKS,
|
| 550 |
+
):
|
| 551 |
+
PRESCALE_QK : tl.constexpr = False
|
| 552 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 553 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 554 |
+
WRITE_DQ : tl.constexpr = True
|
| 555 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 556 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 557 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 558 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 561 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 562 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 563 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 565 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 566 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 567 |
+
BLOCK_M1 : tl.constexpr = 64
|
| 568 |
+
BLOCK_N1 : tl.constexpr = 128
|
| 569 |
+
BLOCK_M2 : tl.constexpr = 128
|
| 570 |
+
BLOCK_N2 : tl.constexpr = 64
|
| 571 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 572 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 573 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 574 |
+
|
| 575 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
Q_LEN = 1104
|
| 578 |
+
KV_LEN = 1104
|
| 579 |
+
|
| 580 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 581 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 582 |
+
|
| 583 |
+
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
| 584 |
+
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
| 585 |
+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
| 586 |
+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
| 587 |
+
|
| 588 |
+
# The minimum is needed to handle the case where we run with a super large
|
| 589 |
+
# SPARSE_BLOCK_SIZE (i.e. no block-mask!)
|
| 590 |
+
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
| 591 |
+
|
| 592 |
+
for start_m in range(0, hi):
|
| 593 |
+
dk, dv = bwd_dkdv_block_mn(
|
| 594 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
|
| 595 |
+
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
| 596 |
+
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
| 597 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 598 |
+
q_indices, sparse_q_num_blocks,
|
| 599 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
# Increment pointers.
|
| 603 |
+
offset = get_offset_for_next_block(
|
| 604 |
+
start_m, q_indices, sparse_q_num_blocks,
|
| 605 |
+
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
qT_ptrs += offset * stride_qm
|
| 609 |
+
do_ptrs += offset * stride_dom
|
| 610 |
+
offs_m1 += offset
|
| 611 |
+
|
| 612 |
+
return dk, dv
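# NOTE (illustrative sketch, not part of the generated kernel): each iteration
# of bwd_dkdv_inner accumulates one Q tile into dK/dV via bwd_dkdv_block_mn.
# Assuming it follows the standard attention backward (masking omitted), the
# per-tile math is roughly:
def _dkdv_block_sketch(q, do, k, v, lse, delta, sm_scale, rcp_ln2=1.44269504):
    import numpy as np
    # q/do: [M, D]/[M, Dv] for the Q tile; k/v: [N, D]/[N, Dv] for this
    # program's KV tile; lse/delta: [M]. Returns (dk, dv) contributions.
    sT = (k @ q.T) * sm_scale * rcp_ln2      # [N, M] scores, transposed layout
    pT = np.exp2(sT - lse[None, :])
    dv = pT @ do                             # [N, Dv]
    dpT = v @ do.T                           # [N, M]
    dsT = pT * (dpT - delta[None, :])
    dk = dsT @ q                             # [N, D]; SM_SCALE applied at write-back
    return dk, dv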
|
| 613 |
+
|
| 614 |
+
|
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    tmp22 = (qkT)
    post_mod_scores = tmp22



    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp23 = (m)
        tmp24 = tl.full([1], 0, tl.int32)
        tmp25 = tmp23 < tmp24
        tmp26 = (n)
        tmp27 = tmp26 <= tmp23
        tmp28 = tmp25 & tmp27
        tmp29 = tmp23 >= tmp24
        tmp30 = tmp26 < tmp24
        tmp31 = tmp29 & tmp30
        tmp32 = tmp30 == 0
        tmp33 = tmp29 & tmp32
        tmp34 = tmp23 - tmp24
        tmp35 = tl.full([1], 16, tl.int32)
        tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35)
        tmp37 = tmp26 - tmp24
        tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35)
        tmp39 = tmp36 == tmp38
        tmp40 = tmp33 & tmp39
        tmp41 = tmp31 | tmp40
        tmp42 = tmp28 | tmp41
        mask_mod_output = tmp42

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    tmp43 = (dsT)
    grad_scores = tmp43



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv

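# --- Illustrative sketch (not part of the generated kernel) ---------------
# bwd_dkdv_block_mn never re-runs softmax; it rebuilds the probabilities from
# the logsumexp saved by the forward pass. A NumPy model of that step,
# assuming toy arrays `scores` (QK^T * SM_SCALE) and `lse2` holding the
# base-2 logsumexp of each row (natural-log LSE times RCP_LN2), could be:
#
#   import numpy as np
#   RCP_LN2 = 1.44269504
#   p = np.exp2(scores * RCP_LN2 - lse2)   # == exp(scores - logsumexp(scores))
#   # dV accumulates p.T @ dO, and dS = p * (dP - Di) feeds the dK update,
#   # matching the dv/dk accumulations above.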
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else: # Both divisible
        return tl.load(ptr)
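# --- Illustrative sketch (not part of the generated kernel) ---------------
# load_checked_2d only masks the dimensions whose lengths are not known to be
# divisible by the tile size; out-of-range elements read as 0.0. A NumPy
# reference for the fully-masked case, assuming a toy 2-D array `buf` and
# index vectors `offs_m`, `offs_n`, would be:
#
#   import numpy as np
#   valid = (offs_m[:, None] < buf.shape[0]) & (offs_n[None, :] < buf.shape[1])
#   tile = np.where(valid,
#                   buf[np.minimum(offs_m, buf.shape[0] - 1)[:, None],
#                       np.minimum(offs_n, buf.shape[1] - 1)[None, :]],
#                   0.0)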
progress/SpecForge/cache/compiled_kernels/6b/c6b364ingwlftc5camjox4wdd5z5l4famigf6ojv4cyji4ju37fy.py ADDED
@@ -0,0 +1,879 @@
# AOT ID: ['5_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/za/czavreibvu56vq4htyyxtbexpf3r3xsyhsclf2t4oq5z37c2h7e5.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
# flex_attention => flex_attention
# Graph fragment:
# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1]
# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1]
# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1]
# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0]
# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg9_1]
# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg6_1]
# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg10_1]
# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg11_1]
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
# return %buf2
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

    num_stages=3,
    num_warps=2,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},

)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = False
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831845
    SPLIT_KV : tl.constexpr = 32
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M : tl.constexpr = 512
    SAFE_M_BOUNDARY : tl.constexpr = False
    SAFE_N_BOUNDARY : tl.constexpr = True
    BLOCK_N : tl.constexpr = 64
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    USE_TMA : tl.constexpr = False
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    M = arg_M
    L = arg_L
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

+
# Sub notation for this kernel:
|
| 106 |
+
# Q: Query, K: Key, V: Value
|
| 107 |
+
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
| 108 |
+
# M: Number of queries, N: Number of keys/values
|
| 109 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 110 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 111 |
+
# BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
|
| 112 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
|
| 113 |
+
# (Modifiable) Config options:
|
| 114 |
+
# SPLIT_KV: number of blocks K & V are split into
|
| 115 |
+
# TILE_KV: length of each local KV split
|
| 116 |
+
# BLOCK_M: block size that Q is padded along seqlen dim.
|
| 117 |
+
# BLOCK_N: block size of K & V along N dimension.
|
| 118 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 119 |
+
#
|
| 120 |
+
# change of base out of the loop
|
| 121 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 122 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 123 |
+
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
| 124 |
+
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
| 125 |
+
|
| 126 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
|
| 127 |
+
#
|
| 128 |
+
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
| 129 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 130 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 131 |
+
#
|
| 132 |
+
#
|
| 133 |
+
# Output: ACC output accumulated across local KV split.
|
| 134 |
+
|
| 135 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 136 |
+
|
| 137 |
+
# Define Q Strides
|
| 138 |
+
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
|
| 139 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 140 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 141 |
+
stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
|
| 142 |
+
stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
Z = 1
|
| 146 |
+
ZKV = 1
|
| 147 |
+
HKV = 8
|
| 148 |
+
G: tl.constexpr = GQA_SHARED_HEADS
|
| 149 |
+
HQ = HKV * G
|
| 150 |
+
Q_LEN = ks0
|
| 151 |
+
KV_LEN = ks1
|
| 152 |
+
|
| 153 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 154 |
+
|
| 155 |
+
# Make sure each split is a multiple of BLOCK_N
|
| 156 |
+
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
| 157 |
+
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
| 158 |
+
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
| 159 |
+
|
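    # --- Illustrative sketch (not part of the generated kernel) -------------
    # The TILE_KV computation above just rounds each of the SPLIT_KV splits up
    # to a whole number of BLOCK_N tiles. With the constants baked into this
    # kernel (SPLIT_KV=32, BLOCK_N=64) and an assumed KV_LEN of 1104:
    #
    #   tile_kv_og       = -(-1104 // 32)        # ceil division -> 35 rows per split
    #   tile_kv          = -(-35 // 64) * 64     # round up to BLOCK_N -> 64
    #   tile_kv_multiple = tile_kv // 64         # -> 1 BLOCK_N tile per split
    #
    # so each CTA along the split axis walks exactly one 64-wide KV tile here.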
    off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
    off_zkv = off_z % ZKV
    off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
    off_t = tl.program_id(1).to(INDEX_DTYPE)

    q_offset = off_z * stride_qz + off_hkv * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    K = K + k_offset
    V = V + v_offset

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_z % SPARSE_Z
    sparse_idx_h = off_hkv % SPARSE_HQ

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)

    # initialize offsets
    tl.device_assert(BLOCK_M % G == 0)
    BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
    off_g = tl.arange(0, G) # [G]
    offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
    offs_hq = offs_g + off_hkv * G
    off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
    offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
    offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Get HZ offsets for KV_NUM_BLKS and KV_IDX
    stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
    sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
    stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
    sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h

    # Calculate KV blocks that belong this CTA.
    block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
    block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N

    q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]

    if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
        q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
    elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
        q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
    elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
        q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
    else:
        q = tl.load(Q + q_offset + q_range)

    q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])


    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # find first kv block we are loading and the number of blocks we are loading
    # Offset the kv_indices tensor by the correct batch and head
    kv_indices = KV_IDX + sparse_idx_hz_offset
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
    MAX_KV_IDX = 1
    indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
    # first kv block we're loading

    # last valid block according to sparse mask
    block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    offs_n = tl.arange(0, BLOCK_N) + off_n

    desc_k = None
    desc_v = None

    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
        q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
        # accumulatd values
        acc, l_i, m_i,
        #offsets
        off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
        off_n,
        #block sparse data
        kv_indices, kv_num_blocks,
        block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )


    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
        # Assign full block in a reverse order for off_t. Prioritize the last CTA.
        block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
        block_n_end = block_n_start + TILE_KV_MULTIPLE
        indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N

        # last valid block according to sparse mask
        block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

        offs_n = tl.arange(0, BLOCK_N) + off_n

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
            q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
            # accumulatd values
            acc, l_i, m_i,
            #offsets
            off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
            off_n,
            #block sparse data
            kv_indices, kv_num_blocks,
            block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )

    m_offset = off_t * stride_mt + off_z * stride_mz
    l_offset = off_t * stride_lt + off_z * stride_lz

    M_block_ptr = tl.make_block_ptr(
        base=M + m_offset,
        shape=(G, Q_LEN), # (G, M)
        strides=(stride_mh, stride_mm),
        offsets=(off_hkv*G, 0),
        block_shape=(G, BLOCK_M_PER_HQ),
        order=(1, 0)
    )
    L_block_ptr = tl.make_block_ptr(
        base=L + l_offset,
        shape=(G, Q_LEN), # (G, M)
        strides=(stride_lh, stride_lm),
        offsets=(off_hkv*G, 0),
        block_shape=(G, BLOCK_M_PER_HQ),
        order=(1, 0)
    )

    # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
    m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
    l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
    if SAFE_M_BOUNDARY:
        tl.store(M_block_ptr, m_i)
        tl.store(L_block_ptr, l_i)
    else:
        tl.store(M_block_ptr, m_i, boundary_check=(1,))
        tl.store(L_block_ptr, l_i, boundary_check=(1,))

    # -- store output
    idx_z = off_z
    idx_t = off_t
    idx_hq = off_hkv*G + off_g[:, None, None]
    idx_m = off_m[None, :, None]
    idx_d = offs_vd[None, None, :]

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
    acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)


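# --- Illustrative sketch (not part of the generated kernel) ---------------
# Each KV split writes its own row-max M and row-sum L; a later reduction
# kernel combines them into one logsumexp per query row. In NumPy terms,
# assuming per-split arrays m[s] and l[s] in the base-2 domain, the merge is
# roughly:
#
#   import numpy as np
#   m_global = m.max(axis=0)
#   l_global = (l * np.exp2(m - m_global)).sum(axis=0)
#   lse      = (np.log2(l_global) + m_global) * 0.6931471805599453  # back to base e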
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else: # Both divisible
        return tl.load(ptr)


| 393 |
+
# Common Imports
|
| 394 |
+
@triton.jit
|
| 395 |
+
def forward_block_mn(
|
| 396 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 397 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 398 |
+
# accumulated values
|
| 399 |
+
acc, l_i, m_i,
|
| 400 |
+
# Offsets
|
| 401 |
+
off_z, off_h, offs_m, offs_n,
|
| 402 |
+
# Offsets needed for TMA loads
|
| 403 |
+
kv_start,
|
| 404 |
+
kv_offset,
|
| 405 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 406 |
+
# Strides for K and V
|
| 407 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 408 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 409 |
+
|
| 410 |
+
):
|
| 411 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 412 |
+
PRESCALE_QK : tl.constexpr = False
|
| 413 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 414 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 415 |
+
WRITE_DQ : tl.constexpr = True
|
| 416 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 417 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 418 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 419 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 420 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 421 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 422 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 423 |
+
SPLIT_KV : tl.constexpr = 32
|
| 424 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 425 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 426 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 427 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 428 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 429 |
+
BLOCK_M : tl.constexpr = 512
|
| 430 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 431 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 432 |
+
BLOCK_N : tl.constexpr = 64
|
| 433 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 434 |
+
USE_TMA : tl.constexpr = False
|
| 435 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# -- load k --
|
| 439 |
+
# NB reversed order to since K is transposed
|
| 440 |
+
kv_base_offset = kv_start + kv_offset
|
| 441 |
+
|
| 442 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 443 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 444 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 445 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 446 |
+
|
| 447 |
+
k = tl.trans(k)
|
| 448 |
+
# -- compute qk ---
|
| 449 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 450 |
+
if not PRESCALE_QK:
|
| 451 |
+
qk *= SM_SCALE
|
| 452 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 453 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 454 |
+
# which is larger than the actual number of elements. To avoid access memory out of bound,
|
| 455 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 456 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 457 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 458 |
+
|
| 459 |
+
tmp0 = (qk)
|
| 460 |
+
post_mod_scores = tmp0
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 464 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 465 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 466 |
+
|
| 467 |
+
if not IS_FULL_BLOCKS:
|
| 468 |
+
tmp1 = (m)
|
| 469 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 470 |
+
tmp3 = tmp1 < tmp2
|
| 471 |
+
tmp4 = (n)
|
| 472 |
+
tmp5 = tmp4 <= tmp1
|
| 473 |
+
tmp6 = tmp3 & tmp5
|
| 474 |
+
tmp7 = tmp1 >= tmp2
|
| 475 |
+
tmp8 = tmp4 < tmp2
|
| 476 |
+
tmp9 = tmp7 & tmp8
|
| 477 |
+
tmp10 = tmp8 == 0
|
| 478 |
+
tmp11 = tmp7 & tmp10
|
| 479 |
+
tmp12 = tmp1 - tmp2
|
| 480 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 481 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 482 |
+
tmp15 = tmp4 - tmp2
|
| 483 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 484 |
+
tmp17 = tmp14 == tmp16
|
| 485 |
+
tmp18 = tmp11 & tmp17
|
| 486 |
+
tmp19 = tmp9 | tmp18
|
| 487 |
+
tmp20 = tmp6 | tmp19
|
| 488 |
+
mask_mod_output = tmp20
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 492 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 493 |
+
# apply mask for partially unmasked blocks
|
| 494 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 495 |
+
|
| 496 |
+
if not PRESCALE_QK:
|
| 497 |
+
post_mod_scores *= RCP_LN2
|
| 498 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 499 |
+
|
| 500 |
+
# -- compute scaling constant ---
|
| 501 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 502 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 503 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 504 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 505 |
+
else:
|
| 506 |
+
m_ij_masked = m_ij
|
| 507 |
+
|
| 508 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 509 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 510 |
+
|
| 511 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 512 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 513 |
+
# m_ij
|
| 514 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 515 |
+
# # -- scale and update acc --
|
| 516 |
+
acc = acc * alpha[:, None]
|
| 517 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 518 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 519 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 520 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 521 |
+
|
| 522 |
+
# -- update m_i
|
| 523 |
+
m_i = m_ij
|
| 524 |
+
|
| 525 |
+
return acc, l_i, m_i
|
| 526 |
+
|
| 527 |
+
@triton.jit
|
| 528 |
+
def forward_inner(
|
| 529 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 530 |
+
q, K, V,
|
| 531 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 532 |
+
# accumulated values
|
| 533 |
+
acc, l_i, m_i,
|
| 534 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 535 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 536 |
+
off_z, off_h, offs_m, offs_n,
|
| 537 |
+
# Offsets needed for TMA loads
|
| 538 |
+
kv_start,
|
| 539 |
+
# blocksparse data
|
| 540 |
+
kv_indices, kv_num_blocks,
|
| 541 |
+
# start kv and end kv block
|
| 542 |
+
block_n_start, block_n_end,
|
| 543 |
+
MATMUL_PRECISION,
|
| 544 |
+
# Strides for K and V
|
| 545 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 546 |
+
IS_FULL_BLOCKS,
|
| 547 |
+
):
|
| 548 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 549 |
+
PRESCALE_QK : tl.constexpr = False
|
| 550 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 551 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 552 |
+
WRITE_DQ : tl.constexpr = True
|
| 553 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 554 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 555 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 556 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 557 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 558 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 559 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 560 |
+
SPLIT_KV : tl.constexpr = 32
|
| 561 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 562 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 563 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 564 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 565 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 566 |
+
BLOCK_M : tl.constexpr = 512
|
| 567 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 568 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 569 |
+
BLOCK_N : tl.constexpr = 64
|
| 570 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 571 |
+
USE_TMA : tl.constexpr = False
|
| 572 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 576 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 577 |
+
|
| 578 |
+
if PRESCALE_QK:
|
| 579 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 580 |
+
|
| 581 |
+
kv_offset = 0
|
| 582 |
+
|
| 583 |
+
# loop over k, v and update accumulator until block_n_end
|
| 584 |
+
for start_n in range(block_n_start, block_n_end):
|
| 585 |
+
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 586 |
+
if IS_DIVISIBLE:
|
| 587 |
+
acc, l_i, m_i = forward_block_mn(
|
| 588 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 589 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 590 |
+
# accumulated values
|
| 591 |
+
acc, l_i, m_i,
|
| 592 |
+
# Offsets
|
| 593 |
+
off_z, off_h, offs_m, offs_n,
|
| 594 |
+
# Offsets needed for TMA loads
|
| 595 |
+
kv_start,
|
| 596 |
+
kv_offset,
|
| 597 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 598 |
+
# Strides for K and V
|
| 599 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 600 |
+
IS_FULL_BLOCKS,
|
| 601 |
+
)
|
| 602 |
+
else:
|
| 603 |
+
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
|
| 604 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 605 |
+
# However, we choose different strategy for bwd, where we only apply mod & mask
|
| 606 |
+
# to the last block because it's faster a lot.
|
| 607 |
+
acc, l_i, m_i = forward_block_mn(
|
| 608 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 609 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 610 |
+
# accumulated values
|
| 611 |
+
acc, l_i, m_i,
|
| 612 |
+
# Offsets
|
| 613 |
+
off_z, off_h, offs_m, offs_n,
|
| 614 |
+
# Offsets needed for TMA loads
|
| 615 |
+
kv_start,
|
| 616 |
+
kv_offset,
|
| 617 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 618 |
+
# Strides for K and V
|
| 619 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 620 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
offset = get_offset_for_next_block(
|
| 626 |
+
start_n, kv_indices, kv_num_blocks,
|
| 627 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
offs_n = offs_n + offset
|
| 631 |
+
kv_offset += offset
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
return acc, l_i, m_i
|
| 635 |
+
''', device_str='cuda')
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7i/c7ijsdt7wst5xe64qslxvdevuoxlscozrh6zigwqavztuz3rptdj.py
|
| 639 |
+
# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
|
| 640 |
+
# Source node to ATen node mapping:
|
| 641 |
+
# flex_attention => flex_attention
|
| 642 |
+
# lse_scaled => mul_9
|
| 643 |
+
# Graph fragment:
|
| 644 |
+
# %buf3 : Tensor = PlaceHolder[target=buf3]
|
| 645 |
+
# %buf4 : Tensor = PlaceHolder[target=buf4]
|
| 646 |
+
# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
|
| 647 |
+
# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
|
| 648 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
|
| 649 |
+
# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {})
|
| 650 |
+
# return %buf5,%buf7,%mul_9
|
| 651 |
+
triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', '''
|
| 652 |
+
import triton
|
| 653 |
+
import triton.language as tl
|
| 654 |
+
|
| 655 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 656 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 657 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 658 |
+
triton_helpers.set_driver_to_gpu()
|
| 659 |
+
|
| 660 |
+
@triton_heuristics.persistent_reduction(
|
| 661 |
+
size_hints={'x': 4096, 'r0_': 32},
|
| 662 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 663 |
+
filename=__file__,
|
| 664 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
|
| 665 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
|
| 666 |
+
)
|
| 667 |
+
@triton.jit
|
| 668 |
+
def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
|
| 669 |
+
r0_numel = 32
|
| 670 |
+
R0_BLOCK: tl.constexpr = 32
|
| 671 |
+
rnumel = r0_numel
|
| 672 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 673 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 674 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 675 |
+
xmask = xindex < xnumel
|
| 676 |
+
r0_index = tl.arange(0, R0_BLOCK)[None, :]
|
| 677 |
+
r0_offset = 0
|
| 678 |
+
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 679 |
+
roffset = r0_offset
|
| 680 |
+
rindex = r0_index
|
| 681 |
+
r0_1 = r0_index
|
| 682 |
+
x0 = xindex
|
| 683 |
+
x2 = (xindex % ks0)
|
| 684 |
+
x3 = triton_helpers.div_floor_integer(xindex, ks0)
|
| 685 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 686 |
+
tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
|
| 687 |
+
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
|
| 688 |
+
tmp3 = tl.where(xmask, tmp1, float("-inf"))
|
| 689 |
+
tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
|
| 690 |
+
tmp6 = float("-inf")
|
| 691 |
+
tmp7 = tmp4 == tmp6
|
| 692 |
+
tmp8 = tmp0 - tmp4
|
| 693 |
+
tmp9 = 0.0
|
| 694 |
+
tmp10 = tl.where(tmp7, tmp9, tmp8)
|
| 695 |
+
tmp11 = libdevice.exp2(tmp10)
|
| 696 |
+
tmp12 = tmp5 * tmp11
|
| 697 |
+
tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
|
| 698 |
+
tmp15 = tl.where(xmask, tmp13, 0)
|
| 699 |
+
tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
|
| 700 |
+
tmp17 = 1.0
|
| 701 |
+
tmp18 = tl.where(tmp7, tmp17, tmp16)
|
| 702 |
+
tmp19 = libdevice.log2(tmp18)
|
| 703 |
+
tmp20 = tmp19 + tmp4
|
| 704 |
+
tmp21 = 0.6931471805599453
|
| 705 |
+
tmp22 = tmp20 * tmp21
|
| 706 |
+
tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
|
| 707 |
+
tl.store(out_ptr0 + (x0), tmp4, xmask)
|
| 708 |
+
tl.store(out_ptr1 + (x0), tmp16, xmask)
|
| 709 |
+
''', device_str='cuda')
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xf/cxfq6ic3efg6yukl55sige4jtfyzy5dlspo6ipsr5mew7q7i32dg.py
|
| 713 |
+
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
|
| 714 |
+
# Source node to ATen node mapping:
|
| 715 |
+
# flex_attention => flex_attention, getitem
|
| 716 |
+
# Graph fragment:
|
| 717 |
+
# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf2]
|
| 718 |
+
# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5]
|
| 719 |
+
# %buf3 : Tensor = PlaceHolder[target=buf3]
|
| 720 |
+
# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf8]
|
| 721 |
+
# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7]
|
| 722 |
+
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {})
# return %buf8,%getitem
triton_per_fused_2 = async_compile.triton('triton_per_fused_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 32},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 32
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    x5 = xindex
    x1 = xindex // 128
    x0 = (xindex % 128)
    x3 = ((xindex // 128) % ks0)
    x4 = xindex // ks1
    tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None)
    tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last')
    tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last')
    tmp2 = float("-inf")
    tmp3 = tmp1 == tmp2
    tmp5 = tmp4 - tmp1
    tmp6 = 0.0
    tmp7 = tl.where(tmp3, tmp6, tmp5)
    tmp8 = libdevice.exp2(tmp7)
    tmp9 = tmp0 * tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
    tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32)
    tmp14 = 1.0
    tmp15 = tl.where(tmp3, tmp14, tmp13)
    tmp16 = (tmp12 / tmp15)
    tmp17 = tmp16.to(tl.float32)
    tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args
        args.clear()
        s50 = arg0_1
        s0 = arg2_1
        s43 = arg4_1
        s37 = arg7_1
        s71 = arg8_1
        assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1))
        assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1))
        assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1))
        assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1))
        assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1))
        assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1))
        with torch.cuda._DeviceGuard(7):
            torch.cuda.set_device(7)
            buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
            buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32)
            buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            stream7 = get_raw_stream(7)
            triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream7)
            del arg10_1
            del arg11_1
            del arg1_1
            del arg3_1
            del arg5_1
            del arg6_1
            del arg9_1
            buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32)
            buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32)
            buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32)
            # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul]
            triton_per_fused_mul_1_xnumel = 32*s37
            stream7 = get_raw_stream(7)
            triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream7)
            del buf1
            ps0 = 128*s37
            buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
            triton_per_fused_2_xnumel = 4096*s37
            stream7 = get_raw_stream(7)
            triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream7)
            del buf0
            del buf2
            del buf5
            del buf7
        return (buf9, buf10, )

runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 96
    arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
    arg2_1 = 96
    arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
    arg4_1 = 96
    arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16)
    arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg7_1 = 96
    arg8_1 = 96
    arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32)
    arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
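The wrapper above is what torch.compile's Inductor backend emits for a flex_attention forward pass: a bf16 query of shape (1, 32, s37, 128), 8 shared KV heads (GQA), softmax scale 0.08838834764831845 = 1/sqrt(128), and a second output holding the logsumexp. For reference, here is a minimal, hedged sketch of the kind of eager-mode call whose compilation produces artifacts like these; the tensor sizes and the causal mask_mod are illustrative assumptions, not taken from this cache, and the API shown is the public torch.nn.attention.flex_attention interface rather than anything SpecForge-specific.

# Hedged illustration only: not the SpecForge training code.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod: True means q_idx may attend to kv_idx (illustrative causal rule).
    return q_idx >= kv_idx

q = torch.randn(1, 32, 96, 128, device="cuda", dtype=torch.bfloat16)   # 32 query heads
k = torch.randn(1, 8, 96, 128, device="cuda", dtype=torch.bfloat16)    # 8 KV heads (GQA)
v = torch.randn(1, 8, 96, 128, device="cuda", dtype=torch.bfloat16)

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=96, KV_LEN=96, device="cuda")
compiled_attn = torch.compile(flex_attention)
# The default scale is 1/sqrt(head_dim) = 0.088388..., matching the graph comment above;
# return_lse=True yields the (output, logsumexp) pair returned here as (buf9, buf10).
out, lse = compiled_attn(q, k, v, block_mask=block_mask, enable_gqa=True, return_lse=True)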
progress/SpecForge/cache/compiled_kernels/6g/826e9651ad1f65a7d666097ac7518bb4b4d3dee1984132523b860dd02b66fff1.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"}
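Each .best_config file records the Triton launch configuration that won autotuning for its sibling kernel, together with a hash of the candidate config list, the tuning time, and the Triton cache entry it maps to. Below is a rough sketch of the same idea in raw Triton; the kernel, its name, and the tuning key are hypothetical and only illustrate how a config space over {XBLOCK, num_warps, num_stages} gets searched (Inductor does this through its @triton_heuristics decorators and then persists the winner as the JSON above).

# Hedged sketch of Triton autotuning; illustrative only.
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"XBLOCK": 4}, num_warps=2, num_stages=1),
        triton.Config({"XBLOCK": 8}, num_warps=2, num_stages=1),
        triton.Config({"XBLOCK": 16}, num_warps=4, num_stages=1),
    ],
    key=["xnumel"],  # re-tune when the problem size changes (hypothetical key)
)
@triton.jit
def scale_by_ln2(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    # Same shape of work as the pointwise mul kernels in this cache:
    # multiply a flat buffer by ln(2).
    offs = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    mask = offs < xnumel
    x = tl.load(in_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 0.6931471805599453, mask=mask)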
progress/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py
ADDED
@@ -0,0 +1,58 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 4096, 'r0_': 32},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 32
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    x2 = (xindex % ks0)
    x3 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
    tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.where(xmask, tmp1, float("-inf"))
    tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
    tmp6 = float("-inf")
    tmp7 = tmp4 == tmp6
    tmp8 = tmp0 - tmp4
    tmp9 = 0.0
    tmp10 = tl.where(tmp7, tmp9, tmp8)
    tmp11 = libdevice.exp2(tmp10)
    tmp12 = tmp5 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    tmp15 = tl.where(xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
    tmp17 = 1.0
    tmp18 = tl.where(tmp7, tmp17, tmp16)
    tmp19 = libdevice.log2(tmp18)
    tmp20 = tmp19 + tmp4
    tmp21 = 0.6931471805599453
    tmp22 = tmp20 * tmp21
    tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
    tl.store(out_ptr0 + (x0), tmp4, xmask)
    tl.store(out_ptr1 + (x0), tmp16, xmask)
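triton_per_fused_mul_1 above reduces a 32-wide chunk dimension: it takes running maxima from in_ptr0 and partial sums from in_ptr1, rescales the sums into the global maximum's frame with exp2, and writes out the combined max, the combined sum, and a natural-log logsumexp (the log2 result times 0.6931471805599453 = ln 2), treating rows whose max is -inf as fully masked. A small NumPy restatement of that arithmetic follows; the function and variable names are mine, not the kernel's.

# Hedged NumPy restatement of the reduction in triton_per_fused_mul_1.
import numpy as np

LN2 = 0.6931471805599453

def combine_lse_base2(m, s):
    # m: per-chunk running maxima in log2 space, shape (..., R)
    # s: per-chunk partial sums, each scaled to its own chunk maximum, shape (..., R)
    row_max = np.max(m, axis=-1)
    fully_masked = row_max == -np.inf
    # Rescale each chunk's sum into the frame of the global max; for a fully
    # masked row every gap is -inf, so exp2 contributes 0 and the total is
    # forced to 1 below (log2(1) == 0), just as the kernel does.
    safe_max = np.where(fully_masked, 0.0, row_max)
    total = (s * np.exp2(m - safe_max[..., None])).sum(axis=-1)
    total = np.where(fully_masked, 1.0, total)
    lse_natural = (np.log2(total) + row_max) * LN2
    return row_max, total, lse_natural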
progress/SpecForge/cache/compiled_kernels/6o/c6ovzyfo6vkdwwzou6dtdvw7qjf65ifmzpcoltl2nx2xuluryjcy.py
ADDED
@@ -0,0 +1,48 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 4096, 'r0_': 128},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 128
    R0_BLOCK: tl.constexpr = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32)
    tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last')
    tmp2 = tmp0 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5 = tl.where(xmask, tmp3, 0)
    tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = 0.6931471805599453
    tmp10 = tmp8 * tmp9
    tmp11 = 1.4426950408889634
    tmp12 = tmp10 * tmp11
    tmp13 = tmp7 - tmp12
    tl.store(out_ptr1 + (x3), tmp13, xmask)
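One detail worth noting in triton_per_fused_mul_0 above: the fp32 value loaded from in_ptr2 is multiplied first by 0.6931471805599453 (ln 2) and then by 1.4426950408889634 (log2 e), and those two constants cancel to within floating-point rounding, so the kernel effectively subtracts the in_ptr2 value as-is from the per-row dot product. A one-line check:

# The paired change-of-base constants in triton_per_fused_mul_0 are inverses.
import math
assert math.isclose(0.6931471805599453 * 1.4426950408889634, 1.0, rel_tol=1e-12)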
progress/SpecForge/cache/compiled_kernels/6o/df004f0eefe2693a59f2bae06581f78ab07b5ce2ec28936911ef13f1152e2ec9.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"}
progress/SpecForge/cache/compiled_kernels/6u/c6ulsdn73forgosxqs5bes2cerczsehypg7jodd4snit3gcqp6el.py
ADDED
@@ -0,0 +1,27 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 32768},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 31232
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py
ADDED
|
@@ -0,0 +1,534 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=8,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 28 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 29 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 30 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 31 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 32 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 34 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 35 |
+
USE_TMA : tl.constexpr = False
|
| 36 |
+
BLOCK_M : tl.constexpr = 128
|
| 37 |
+
BLOCK_N : tl.constexpr = 64
|
| 38 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 39 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 40 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 41 |
+
Q = arg_Q
|
| 42 |
+
K = arg_K
|
| 43 |
+
V = arg_V
|
| 44 |
+
LSE = arg_LSE
|
| 45 |
+
MAX = arg_MAX
|
| 46 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 47 |
+
KV_IDX = arg_KV_IDX
|
| 48 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 49 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 50 |
+
|
| 51 |
+
# Sub notation for this kernel:
|
| 52 |
+
#
|
| 53 |
+
# Q: Query, K: Key, V: Value
|
| 54 |
+
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
| 55 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 56 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 57 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
| 58 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 59 |
+
#
|
| 60 |
+
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
| 61 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 62 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 63 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 64 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 65 |
+
#
|
| 66 |
+
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
| 67 |
+
#
|
| 68 |
+
# (Modifiable) Performance tuning options
|
| 69 |
+
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
| 70 |
+
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
| 71 |
+
|
| 72 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 73 |
+
# or involve a numerics vs. perf tradeoff
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 75 |
+
# about 20% more numerical error, but slightly faster.
|
| 76 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 77 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 78 |
+
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
| 79 |
+
# contiguous? If so, we don't need to do an indirect jump for every block
|
| 80 |
+
|
| 81 |
+
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
| 82 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 83 |
+
|
| 84 |
+
# Define strides of inputs
|
| 85 |
+
stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
|
| 86 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 87 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1
|
| 88 |
+
|
| 89 |
+
ZQ = 1
|
| 90 |
+
HQ = 32
|
| 91 |
+
Q_LEN = ks0
|
| 92 |
+
ZKV = 1
|
| 93 |
+
KV_LEN = ks1
|
| 94 |
+
|
| 95 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 96 |
+
|
| 97 |
+
q_start = tl.program_id(0).to(INDEX_DTYPE)
|
| 98 |
+
off_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 99 |
+
off_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 100 |
+
|
| 101 |
+
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
|
| 102 |
+
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
|
| 103 |
+
off_zkv = off_zq % ZKV
|
| 104 |
+
off_hkv = off_hq // GQA_SHARED_HEADS
|
| 105 |
+
off_g = off_hq % GQA_SHARED_HEADS
|
| 106 |
+
|
| 107 |
+
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
| 108 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 109 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 110 |
+
|
| 111 |
+
Q = Q + q_offset
|
| 112 |
+
K = K + k_offset
|
| 113 |
+
V = V + v_offset
|
| 114 |
+
|
| 115 |
+
# Setting up the TMA descriptors for Q, K, V
|
| 116 |
+
desc_q = None
|
| 117 |
+
desc_k = None
|
| 118 |
+
desc_v = None
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_zq % SPARSE_Z
|
| 124 |
+
sparse_idx_hq = off_hq % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
| 127 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 128 |
+
|
| 129 |
+
stride_kv_num_blks_h = 1
|
| 130 |
+
stride_kv_idx_h = 1
|
| 131 |
+
stride_kv_idx_m = 1
|
| 132 |
+
|
| 133 |
+
# initialize pointer to m and l
|
| 134 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 135 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 136 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 137 |
+
|
| 138 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 139 |
+
|
| 140 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 141 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
| 142 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
| 143 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
| 144 |
+
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 145 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 146 |
+
q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
| 147 |
+
|
| 148 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 149 |
+
# We don't know anything "special" about these blocks, so we need to apply
|
| 150 |
+
# both score_mod and mask_mod to it
|
| 151 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 152 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 153 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 154 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# K and V pointers will be passed directly to forward_inner
|
| 158 |
+
|
| 159 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
acc, l_i, m_i = forward_inner(
|
| 163 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 164 |
+
q, K, V,
|
| 165 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 166 |
+
acc, l_i, m_i,
|
| 167 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 168 |
+
kv_start,
|
| 169 |
+
kv_indices, kv_num_blocks,
|
| 170 |
+
0, block_n_end,
|
| 171 |
+
MATMUL_PRECISION,
|
| 172 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 173 |
+
IS_FULL_BLOCKS=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 177 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 178 |
+
# apply mask_mod to them - only score_mod
|
| 179 |
+
if HAS_FULL_BLOCKS:
|
| 180 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 181 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 182 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 183 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 184 |
+
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 185 |
+
# K and V pointers will be passed directly to forward_inner
|
| 186 |
+
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 190 |
+
q, K, V,
|
| 191 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
| 194 |
+
kv_start,
|
| 195 |
+
kv_indices, kv_num_blocks,
|
| 196 |
+
0, block_n_end,
|
| 197 |
+
MATMUL_PRECISION,
|
| 198 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 199 |
+
IS_FULL_BLOCKS=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# [Note] Handle fully masked out rows:
|
| 204 |
+
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
| 205 |
+
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
|
| 206 |
+
l_i = tl.where(l_i == 0.0, 1, l_i)
|
| 207 |
+
|
| 208 |
+
acc = acc / l_i[:, None]
|
| 209 |
+
idx_zq = tl.program_id(1).to(INDEX_DTYPE)
|
| 210 |
+
idx_hq = tl.program_id(2).to(INDEX_DTYPE)
|
| 211 |
+
idx_m = offs_m[:, None].to(INDEX_DTYPE)
|
| 212 |
+
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
|
| 213 |
+
|
| 214 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 215 |
+
|
| 216 |
+
tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
|
| 217 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
|
| 218 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)
|
| 219 |
+
|
| 220 |
+
if OUTPUT_LOGSUMEXP:
|
| 221 |
+
off_hz = off_zq * HQ + off_hq
|
| 222 |
+
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
| 223 |
+
lse = m_i + tl.math.log2(l_i)
|
| 224 |
+
if IS_DIVISIBLE:
|
| 225 |
+
tl.store(l_ptrs, lse)
|
| 226 |
+
else:
|
| 227 |
+
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
| 228 |
+
|
| 229 |
+
if OUTPUT_MAX:
|
| 230 |
+
off_hz = off_zq * HQ + off_hq
|
| 231 |
+
max_ptrs = MAX + off_hz * Q_LEN + offs_m
|
| 232 |
+
if IS_DIVISIBLE:
|
| 233 |
+
tl.store(max_ptrs, m_i)
|
| 234 |
+
else:
|
| 235 |
+
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Utility triton funcs
|
| 239 |
+
@triton.jit
|
| 240 |
+
def get_offset_for_next_block(
|
| 241 |
+
loop_iter, col_indices, total_blocks,
|
| 242 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 243 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 244 |
+
):
|
| 245 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 246 |
+
return BLOCK
|
| 247 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 248 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 249 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 250 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 251 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 252 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 253 |
+
return offset
|
| 254 |
+
|
| 255 |
+
@triton.jit
|
| 256 |
+
def get_bounded_indices(indices, max_len=None):
|
| 257 |
+
return indices % max_len if max_len is not None else indices
|
| 258 |
+
|
| 259 |
+
@triton.jit
|
| 260 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 261 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 262 |
+
return tl.load(block_ptr)
|
| 263 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 264 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 265 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 266 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 267 |
+
else:
|
| 268 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 269 |
+
|
| 270 |
+
@triton.jit
|
| 271 |
+
def load_checked_2d(
|
| 272 |
+
ptr,
|
| 273 |
+
offs_m,
|
| 274 |
+
offs_n,
|
| 275 |
+
stride_m,
|
| 276 |
+
stride_n,
|
| 277 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 278 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 279 |
+
M_LEN: tl.constexpr,
|
| 280 |
+
N_LEN: tl.constexpr,
|
| 281 |
+
):
|
| 282 |
+
# Calculate final pointer if strides are provided
|
| 283 |
+
if stride_m is not None and stride_n is not None:
|
| 284 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 285 |
+
|
| 286 |
+
# Handle all masking cases
|
| 287 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 288 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 289 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 290 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 291 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 292 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 293 |
+
else: # Both divisible
|
| 294 |
+
return tl.load(ptr)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Common Imports
|
| 298 |
+
@triton.jit
|
| 299 |
+
def forward_block_mn(
|
| 300 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 301 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 302 |
+
# accumulated values
|
| 303 |
+
acc, l_i, m_i,
|
| 304 |
+
# Offsets
|
| 305 |
+
off_z, off_h, offs_m, offs_n,
|
| 306 |
+
# Offsets needed for TMA loads
|
| 307 |
+
kv_start,
|
| 308 |
+
kv_offset,
|
| 309 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 310 |
+
# Strides for K and V
|
| 311 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 312 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 313 |
+
|
| 314 |
+
):
|
| 315 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 316 |
+
PRESCALE_QK : tl.constexpr = False
|
| 317 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 318 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 319 |
+
WRITE_DQ : tl.constexpr = True
|
| 320 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 321 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 322 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 323 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 324 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 325 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 326 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 327 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 328 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 329 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 330 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 331 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 332 |
+
USE_TMA : tl.constexpr = False
|
| 333 |
+
BLOCK_M : tl.constexpr = 128
|
| 334 |
+
BLOCK_N : tl.constexpr = 64
|
| 335 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 336 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 337 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# -- load k --
|
| 341 |
+
# NB reversed order since K is transposed
|
| 342 |
+
kv_base_offset = kv_start + kv_offset
|
| 343 |
+
|
| 344 |
+
# Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
|
| 345 |
+
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 346 |
+
offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
|
| 347 |
+
k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
| 348 |
+
|
| 349 |
+
k = tl.trans(k)
|
| 350 |
+
# -- compute qk ---
|
| 351 |
+
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
| 352 |
+
if not PRESCALE_QK:
|
| 353 |
+
qk *= SM_SCALE
|
| 354 |
+
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
| 355 |
+
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
|
| 356 |
+
# which is larger than the actual number of elements. To avoid access memory out of bound,
|
| 357 |
+
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
| 358 |
+
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 359 |
+
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
| 360 |
+
|
| 361 |
+
tmp0 = (qk)
|
| 362 |
+
post_mod_scores = tmp0
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 366 |
+
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
| 367 |
+
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
| 368 |
+
|
| 369 |
+
if not IS_FULL_BLOCKS:
|
| 370 |
+
tmp1 = (m)
|
| 371 |
+
tmp2 = tl.full([1], 0, tl.int32)
|
| 372 |
+
tmp3 = tmp1 < tmp2
|
| 373 |
+
tmp4 = (n)
|
| 374 |
+
tmp5 = tmp4 <= tmp1
|
| 375 |
+
tmp6 = tmp3 & tmp5
|
| 376 |
+
tmp7 = tmp1 >= tmp2
|
| 377 |
+
tmp8 = tmp4 < tmp2
|
| 378 |
+
tmp9 = tmp7 & tmp8
|
| 379 |
+
tmp10 = tmp8 == 0
|
| 380 |
+
tmp11 = tmp7 & tmp10
|
| 381 |
+
tmp12 = tmp1 - tmp2
|
| 382 |
+
tmp13 = tl.full([1], 16, tl.int32)
|
| 383 |
+
tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
|
| 384 |
+
tmp15 = tmp4 - tmp2
|
| 385 |
+
tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
|
| 386 |
+
tmp17 = tmp14 == tmp16
|
| 387 |
+
tmp18 = tmp11 & tmp17
|
| 388 |
+
tmp19 = tmp9 | tmp18
|
| 389 |
+
tmp20 = tmp6 | tmp19
|
| 390 |
+
mask_mod_output = tmp20
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if CHECK_BLOCK_BOUNDARY:
|
| 394 |
+
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
| 395 |
+
# apply mask for partially unmasked blocks
|
| 396 |
+
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
| 397 |
+
|
| 398 |
+
if not PRESCALE_QK:
|
| 399 |
+
post_mod_scores *= RCP_LN2
|
| 400 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 401 |
+
|
| 402 |
+
# -- compute scaling constant ---
|
| 403 |
+
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
| 404 |
+
if not ROWS_GUARANTEED_SAFE:
|
| 405 |
+
masked_out_rows = (m_ij == float("-inf"))
|
| 406 |
+
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
| 407 |
+
else:
|
| 408 |
+
m_ij_masked = m_ij
|
| 409 |
+
|
| 410 |
+
alpha = tl.math.exp2(m_i - m_ij_masked)
|
| 411 |
+
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
| 412 |
+
|
| 413 |
+
# NB: l_i update is pulled up here since it's a bit faster
|
| 414 |
+
# NB: For headdim=256, it's faster to move it back down to after m_i =
|
| 415 |
+
# m_ij
|
| 416 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
| 417 |
+
# # -- scale and update acc --
|
| 418 |
+
acc = acc * alpha[:, None]
|
| 419 |
+
# Calculate offsets for V loading - reuse kv_base_offset from K loading
|
| 420 |
+
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 421 |
+
v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
| 422 |
+
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
| 423 |
+
|
| 424 |
+
# -- update m_i
|
| 425 |
+
m_i = m_ij
|
| 426 |
+
|
| 427 |
+
return acc, l_i, m_i
|
| 428 |
+
|
| 429 |
+
@triton.jit
|
| 430 |
+
def forward_inner(
|
| 431 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 432 |
+
q, K, V,
|
| 433 |
+
desc_k, desc_v, Q_LEN, KV_LEN,
|
| 434 |
+
# accumulated values
|
| 435 |
+
acc, l_i, m_i,
|
| 436 |
+
# Offsets used as inputs to score_mod & mask_mod
|
| 437 |
+
# of size [BLOCK_M, BLOCK_N] or scalar.
|
| 438 |
+
off_z, off_h, offs_m, offs_n,
|
| 439 |
+
# Offsets needed for TMA loads
|
| 440 |
+
kv_start,
|
| 441 |
+
# blocksparse data
|
| 442 |
+
kv_indices, kv_num_blocks,
|
| 443 |
+
# start kv and end kv block
|
| 444 |
+
block_n_start, block_n_end,
|
| 445 |
+
MATMUL_PRECISION,
|
| 446 |
+
# Strides for K and V
|
| 447 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 448 |
+
IS_FULL_BLOCKS,
|
| 449 |
+
):
|
| 450 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 451 |
+
PRESCALE_QK : tl.constexpr = False
|
| 452 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 453 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 454 |
+
WRITE_DQ : tl.constexpr = True
|
| 455 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 456 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 457 |
+
FLOAT32_PRECISION : tl.constexpr = 'ieee'
|
| 458 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 459 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 460 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 461 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 462 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 463 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 464 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 465 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 466 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 467 |
+
USE_TMA : tl.constexpr = False
|
| 468 |
+
BLOCK_M : tl.constexpr = 128
|
| 469 |
+
BLOCK_N : tl.constexpr = 64
|
| 470 |
+
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
|
| 471 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 472 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 476 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 477 |
+
|
| 478 |
+
if PRESCALE_QK:
|
| 479 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 480 |
+
|
| 481 |
+
kv_offset = 0
|
| 482 |
+
|
| 483 |
+
# loop over k, v and update accumulator until block_n_end
|
| 484 |
+
for start_n in range(block_n_start, block_n_end):
|
| 485 |
+
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
| 486 |
+
if IS_DIVISIBLE:
|
| 487 |
+
acc, l_i, m_i = forward_block_mn(
|
| 488 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 489 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 490 |
+
# accumulated values
|
| 491 |
+
acc, l_i, m_i,
|
| 492 |
+
# Offsets
|
| 493 |
+
off_z, off_h, offs_m, offs_n,
|
| 494 |
+
# Offsets needed for TMA loads
|
| 495 |
+
kv_start,
|
| 496 |
+
kv_offset,
|
| 497 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 498 |
+
# Strides for K and V
|
| 499 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 500 |
+
IS_FULL_BLOCKS,
|
| 501 |
+
)
|
| 502 |
+
else:
|
| 503 |
+
# Benchmarks show that even when we apply mod & mask to each block for non-divisible seqlen,
|
| 504 |
+
# it's on par or slightly faster than only applying to the last block in fwd.
|
| 505 |
+
# However, we choose a different strategy for bwd, where we only apply mod & mask
|
| 506 |
+
# to the last block because it is much faster.
|
| 507 |
+
acc, l_i, m_i = forward_block_mn(
|
| 508 |
+
arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2,
|
| 509 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 510 |
+
# accumulated values
|
| 511 |
+
acc, l_i, m_i,
|
| 512 |
+
# Offsets
|
| 513 |
+
off_z, off_h, offs_m, offs_n,
|
| 514 |
+
# Offsets needed for TMA loads
|
| 515 |
+
kv_start,
|
| 516 |
+
kv_offset,
|
| 517 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 518 |
+
# Strides for K and V
|
| 519 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 520 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
offset = get_offset_for_next_block(
|
| 526 |
+
start_n, kv_indices, kv_num_blocks,
|
| 527 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
offs_n = offs_n + offset
|
| 531 |
+
kv_offset += offset
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
return acc, l_i, m_i
|
progress/SpecForge/cache/compiled_kernels/6u/e5329724392dcdd68d88f082f57e8929de539c9aa187c3a314edefdc595437d5.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA"}
progress/SpecForge/cache/compiled_kernels/7a/aee791ee3934869dfa55caee4270f116dc979737c2bbcce40af5d5394ccc9ac8.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/7a/c7a2brsshxp6zz4foe62t5ivwbd2dwr6ytjbhxp22vq2evdotx5z.py
ADDED
@@ -0,0 +1,28 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4096},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % ks0)
    x1 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    tmp1 = 0.6931471805599453
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask)
progress/SpecForge/cache/compiled_kernels/7a/c7adkdqab5cvqxxnwnn5au23gorh2eg33cfxneh7bzb7untnuvpw.py
ADDED
|
@@ -0,0 +1,534 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block
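    # Reader note (added commentary, not emitted by the compiler): with the concrete
    # constexpr values above, each program handles BLOCK_M = 128 query rows and sweeps
    # keys/values in BLOCK_N = 64 chunks, so one SPARSE_KV_BLOCK_SIZE = 128 mask block
    # covers two inner iterations. With GQA_SHARED_HEADS = 4, query heads 0-3 read KV
    # head 0, heads 4-7 read KV head 1, and so on (off_hkv = off_hq // 4 below).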
    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
    # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    # Setting up the TMA descriptors for Q, K, V
    desc_q = None
    desc_k = None
    desc_v = None

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    stride_kv_num_blks_h = ks2
    stride_kv_idx_h = ks3*ks4
    stride_kv_idx_m = ks4

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to it
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    # K and V pointers will be passed directly to forward_inner
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )

    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)

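# Reader note (added commentary): the softmax above is kept in base 2. Scores are
# multiplied by RCP_LN2 = 1/ln(2) inside forward_block_mn and exp2/log2 replace
# exp/log, so the stored lse = m_i + log2(l_i) is a log2-domain logsumexp; a
# natural-log value would be lse * ln(2), roughly 0.6931 * lse. Forcing l_i to 1.0
# above also keeps the division acc / l_i[:, None] well defined for fully masked rows.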
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)

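# Reader note (added commentary): a rough trace of get_offset_for_next_block with
# SPARSE_BLOCK = 128, SPARSE_BLOCK_MULTIPLE = 2, BLOCK = 64. While still inside one
# sparse block ((loop_iter + 1) % 2 != 0) the pointer simply advances by BLOCK = 64.
# On the last inner iteration it jumps to the next block listed in col_indices, e.g.
# moving from sparse block 3 to sparse block 7 gives (7 - 3) * 128 - (2 - 1) * 64 = 448,
# which lands exactly at the start of block 7 (3 * 128 + 64 + 448 = 7 * 128).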
# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # -- load k --
    # NB reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid access memory out of bound,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

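# Reader note (added commentary): forward_block_mn above is one step of a streaming
# (online) softmax. With m_old the previous running row max and m_new the updated one,
# alpha = exp2(m_old - m_new) rescales the carried state so that
#   l_new = alpha * l_old + sum_j exp2(score_j - m_new)
#   acc_new = alpha * acc_old + P @ V
# match a softmax taken over every key seen so far. For example, if a row's max moves
# from 2.0 to 5.0, the old sums are multiplied by exp2(-3.0) = 0.125 before the new
# block's contribution is added.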
@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even if we apply mod & mask to each block for non divisible seqlen,
            # it's on par or slightly faster than only applying to the last block in fwd.
            # However, we choose a different strategy for bwd, where we only apply mod & mask
            # to the last block because it's a lot faster.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i
progress/SpecForge/cache/compiled_kernels/7i/c7ijsdt7wst5xe64qslxvdevuoxlscozrh6zigwqavztuz3rptdj.py
ADDED
@@ -0,0 +1,58 @@
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 4096, 'r0_': 32},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 32
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    x2 = (xindex % ks0)
    x3 = triton_helpers.div_floor_integer(xindex, ks0)
    tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
    tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.where(xmask, tmp1, float("-inf"))
    tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32)
    tmp6 = float("-inf")
    tmp7 = tmp4 == tmp6
    tmp8 = tmp0 - tmp4
    tmp9 = 0.0
    tmp10 = tl.where(tmp7, tmp9, tmp8)
    tmp11 = libdevice.exp2(tmp10)
    tmp12 = tmp5 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    tmp15 = tl.where(xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
    tmp17 = 1.0
    tmp18 = tl.where(tmp7, tmp17, tmp16)
    tmp19 = libdevice.log2(tmp18)
    tmp20 = tmp19 + tmp4
    tmp21 = 0.6931471805599453
    tmp22 = tmp20 * tmp21
    tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask)
    tl.store(out_ptr0 + (x0), tmp4, xmask)
    tl.store(out_ptr1 + (x0), tmp16, xmask)
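# Reader note (added commentary, an interpretation of the arithmetic above): this
# persistent reduction folds 32 partial results per row (the r0_ axis) into one value:
# tmp4 is the max across splits, tmp12/tmp16 rescale and sum the partial sums by
# exp2(partial_max - global_max), and tmp22 converts the base-2 logsumexp back to
# natural log by multiplying with ln(2) = 0.6931471805599453. Rows whose max is still
# -inf are clamped so that log2() sees 1.0 instead of 0.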
progress/SpecForge/cache/compiled_kernels/7i/e1454135615b9b6420e5ef4fe0804f1b1346f398b5afb34dbe8085f0e900c8aa.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "RHE7JXFOJQBB2ESJBNT5OZ3PCTRJ7WSJRB7A2GLRM73N3EI7TWDQ"}
progress/SpecForge/cache/compiled_kernels/7n/c7n4jk5r4lsbq62vtrxzouvawlecnbfhy3owedw4ewuid7d56bjs.py
ADDED
@@ -0,0 +1,582 @@
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
|
| 9 |
+
@triton_heuristics.template(
|
| 10 |
+
|
| 11 |
+
num_stages=3,
|
| 12 |
+
num_warps=2,
|
| 13 |
+
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
|
| 14 |
+
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}},
|
| 15 |
+
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
|
| 19 |
+
PRESCALE_QK : tl.constexpr = False
|
| 20 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 21 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 22 |
+
WRITE_DQ : tl.constexpr = True
|
| 23 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 24 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 25 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 26 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 27 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 28 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 29 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 30 |
+
SPLIT_KV : tl.constexpr = 32
|
| 31 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 32 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 33 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 34 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 35 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 36 |
+
BLOCK_M : tl.constexpr = 512
|
| 37 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 38 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 39 |
+
BLOCK_N : tl.constexpr = 64
|
| 40 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 41 |
+
USE_TMA : tl.constexpr = False
|
| 42 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 43 |
+
Q = arg_Q
|
| 44 |
+
K = arg_K
|
| 45 |
+
V = arg_V
|
| 46 |
+
M = arg_M
|
| 47 |
+
L = arg_L
|
| 48 |
+
KV_NUM_BLKS = arg_KV_NUM_BLKS
|
| 49 |
+
KV_IDX = arg_KV_IDX
|
| 50 |
+
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
|
| 51 |
+
FULL_KV_IDX = arg_FULL_KV_IDX
|
| 52 |
+
|
| 53 |
+
# Sub notation for this kernel:
|
| 54 |
+
# Q: Query, K: Key, V: Value
|
| 55 |
+
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
| 56 |
+
# M: Number of queries, N: Number of keys/values
|
| 57 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 58 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 59 |
+
# BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
|
| 60 |
+
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
|
| 61 |
+
# (Modifiable) Config options:
|
| 62 |
+
# SPLIT_KV: number of blocks K & V are split into
|
| 63 |
+
# TILE_KV: length of each local KV split
|
| 64 |
+
# BLOCK_M: block size that Q is padded along seqlen dim.
|
| 65 |
+
# BLOCK_N: block size of K & V along N dimension.
|
| 66 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 67 |
+
#
|
| 68 |
+
# change of base out of the loop
|
| 69 |
+
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
| 70 |
+
# is not masked out? If so, we can skip an extra safety check
|
| 71 |
+
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
| 72 |
+
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
| 73 |
+
|
| 74 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
|
| 75 |
+
#
|
| 76 |
+
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
| 77 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 78 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 79 |
+
#
|
| 80 |
+
#
|
| 81 |
+
# Output: ACC output accumulated across local KV split.
|
| 82 |
+
|
| 83 |
+
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
| 84 |
+
|
| 85 |
+
# Define Q Strides
|
| 86 |
+
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1
|
| 87 |
+
stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
|
| 88 |
+
stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1
|
| 89 |
+
stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1
|
| 90 |
+
stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
Z = 1
|
| 94 |
+
ZKV = 1
|
| 95 |
+
HKV = 8
|
| 96 |
+
G: tl.constexpr = GQA_SHARED_HEADS
|
| 97 |
+
HQ = HKV * G
|
| 98 |
+
Q_LEN = ks0
|
| 99 |
+
KV_LEN = ks1
|
| 100 |
+
|
| 101 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 102 |
+
|
| 103 |
+
# Make sure each split is a multiple of BLOCK_N
|
| 104 |
+
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
| 105 |
+
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
| 106 |
+
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
| 107 |
+
|
| 108 |
+
off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV
|
| 109 |
+
off_zkv = off_z % ZKV
|
| 110 |
+
off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV
|
| 111 |
+
off_t = tl.program_id(1).to(INDEX_DTYPE)
|
| 112 |
+
|
| 113 |
+
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
| 114 |
+
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
| 115 |
+
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
| 116 |
+
|
| 117 |
+
K = K + k_offset
|
| 118 |
+
V = V + v_offset
|
| 119 |
+
|
| 120 |
+
SPARSE_Z = 1
|
| 121 |
+
SPARSE_HQ = 1
|
| 122 |
+
|
| 123 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 124 |
+
sparse_idx_h = off_hkv % SPARSE_HQ
|
| 125 |
+
|
| 126 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
| 127 |
+
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
| 128 |
+
|
| 129 |
+
# initialize pointer to m and l
|
| 130 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 131 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
| 132 |
+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
| 133 |
+
|
| 134 |
+
# initialize offsets
|
| 135 |
+
tl.device_assert(BLOCK_M % G == 0)
|
| 136 |
+
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
| 137 |
+
off_g = tl.arange(0, G) # [G]
|
| 138 |
+
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 139 |
+
offs_hq = offs_g + off_hkv * G
|
| 140 |
+
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
| 141 |
+
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
| 142 |
+
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
| 143 |
+
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
| 144 |
+
|
| 145 |
+
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
| 146 |
+
stride_block_z, stride_block_h, stride_block_row = 1, 1, 1
|
| 147 |
+
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
| 148 |
+
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1
|
| 149 |
+
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
| 150 |
+
|
| 151 |
+
# Calculate KV blocks that belong this CTA.
|
| 152 |
+
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
| 153 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
| 154 |
+
|
| 155 |
+
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
| 156 |
+
|
| 157 |
+
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 158 |
+
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
|
| 159 |
+
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
| 160 |
+
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
|
| 161 |
+
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
|
| 162 |
+
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
|
| 163 |
+
else:
|
| 164 |
+
q = tl.load(Q + q_offset + q_range)
|
| 165 |
+
|
| 166 |
+
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 170 |
+
# find first kv block we are loading and the number of blocks we are loading
|
| 171 |
+
# Offset the kv_indices tensor by the correct batch and head
|
| 172 |
+
kv_indices = KV_IDX + sparse_idx_hz_offset
|
| 173 |
+
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
|
| 174 |
+
MAX_KV_IDX = 1
|
| 175 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 176 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 177 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 178 |
+
# first kv block we're loading
|
| 179 |
+
|
| 180 |
+
# last valid block according to sparse mask
|
| 181 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 182 |
+
|
| 183 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 184 |
+
|
| 185 |
+
desc_k = None
|
| 186 |
+
desc_v = None
|
| 187 |
+
|
| 188 |
+
acc, l_i, m_i = forward_inner(
|
| 189 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 190 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 191 |
+
# accumulatd values
|
| 192 |
+
acc, l_i, m_i,
|
| 193 |
+
#offsets
|
| 194 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 195 |
+
off_n,
|
| 196 |
+
#block sparse data
|
| 197 |
+
kv_indices, kv_num_blocks,
|
| 198 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 199 |
+
MATMUL_PRECISION,
|
| 200 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 201 |
+
IS_FULL_BLOCKS=False,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 206 |
+
# We know these blocks are guaranteed to be "full", so we don't need to
|
| 207 |
+
# apply mask_mod to them - only score_mod
|
| 208 |
+
if HAS_FULL_BLOCKS:
|
| 209 |
+
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
|
| 210 |
+
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
|
| 211 |
+
# Assign full block in a reverse order for off_t. Prioritize the last CTA.
|
| 212 |
+
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
|
| 213 |
+
block_n_end = block_n_start + TILE_KV_MULTIPLE
|
| 214 |
+
indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX)
|
| 215 |
+
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
| 216 |
+
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
| 217 |
+
|
| 218 |
+
# last valid block according to sparse mask
|
| 219 |
+
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
| 220 |
+
|
| 221 |
+
offs_n = tl.arange(0, BLOCK_N) + off_n
|
| 222 |
+
|
| 223 |
+
acc, l_i, m_i = forward_inner(
|
| 224 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 225 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 226 |
+
# accumulatd values
|
| 227 |
+
acc, l_i, m_i,
|
| 228 |
+
#offsets
|
| 229 |
+
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
| 230 |
+
off_n,
|
| 231 |
+
#block sparse data
|
| 232 |
+
kv_indices, kv_num_blocks,
|
| 233 |
+
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
| 234 |
+
MATMUL_PRECISION,
|
| 235 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 236 |
+
IS_FULL_BLOCKS=True,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
m_offset = off_t * stride_mt + off_z * stride_mz
|
| 240 |
+
l_offset = off_t * stride_lt + off_z * stride_lz
|
| 241 |
+
|
| 242 |
+
M_block_ptr = tl.make_block_ptr(
|
| 243 |
+
base=M + m_offset,
|
| 244 |
+
shape=(G, Q_LEN), # (G, M)
|
| 245 |
+
strides=(stride_mh, stride_mm),
|
| 246 |
+
offsets=(off_hkv*G, 0),
|
| 247 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 248 |
+
order=(1, 0)
|
| 249 |
+
)
|
| 250 |
+
L_block_ptr = tl.make_block_ptr(
|
| 251 |
+
base=L + l_offset,
|
| 252 |
+
shape=(G, Q_LEN), # (G, M)
|
| 253 |
+
strides=(stride_lh, stride_lm),
|
| 254 |
+
offsets=(off_hkv*G, 0),
|
| 255 |
+
block_shape=(G, BLOCK_M_PER_HQ),
|
| 256 |
+
order=(1, 0)
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
| 260 |
+
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
| 261 |
+
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
| 262 |
+
if SAFE_M_BOUNDARY:
|
| 263 |
+
tl.store(M_block_ptr, m_i)
|
| 264 |
+
tl.store(L_block_ptr, l_i)
|
| 265 |
+
else:
|
| 266 |
+
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
| 267 |
+
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
| 268 |
+
|
| 269 |
+
# -- store output
|
| 270 |
+
idx_z = off_z
|
| 271 |
+
idx_t = off_t
|
| 272 |
+
idx_hq = off_hkv*G + off_g[:, None, None]
|
| 273 |
+
idx_m = off_m[None, :, None]
|
| 274 |
+
idx_d = offs_vd[None, None, :]
|
| 275 |
+
|
| 276 |
+
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
| 277 |
+
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
|
| 278 |
+
xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0
|
| 279 |
+
tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Utility triton funcs
|
| 283 |
+
@triton.jit
|
| 284 |
+
def get_offset_for_next_block(
|
| 285 |
+
loop_iter, col_indices, total_blocks,
|
| 286 |
+
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
|
| 287 |
+
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
|
| 288 |
+
):
|
| 289 |
+
if BLOCKS_ARE_CONTIGUOUS:
|
| 290 |
+
return BLOCK
|
| 291 |
+
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
|
| 292 |
+
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
|
| 293 |
+
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
|
| 294 |
+
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
|
| 295 |
+
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
|
| 296 |
+
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
|
| 297 |
+
return offset
|
| 298 |
+
|
| 299 |
+
@triton.jit
|
| 300 |
+
def get_bounded_indices(indices, max_len=None):
|
| 301 |
+
return indices % max_len if max_len is not None else indices
|
| 302 |
+
|
| 303 |
+
@triton.jit
|
| 304 |
+
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
|
| 305 |
+
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 306 |
+
return tl.load(block_ptr)
|
| 307 |
+
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
|
| 308 |
+
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
|
| 309 |
+
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
|
| 310 |
+
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
|
| 311 |
+
else:
|
| 312 |
+
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
|
| 313 |
+
|
| 314 |
+
@triton.jit
|
| 315 |
+
def load_checked_2d(
|
| 316 |
+
ptr,
|
| 317 |
+
offs_m,
|
| 318 |
+
offs_n,
|
| 319 |
+
stride_m,
|
| 320 |
+
stride_n,
|
| 321 |
+
IS_DIVISIBLE_M: tl.constexpr,
|
| 322 |
+
IS_DIVISIBLE_N: tl.constexpr,
|
| 323 |
+
M_LEN: tl.constexpr,
|
| 324 |
+
N_LEN: tl.constexpr,
|
| 325 |
+
):
|
| 326 |
+
# Calculate final pointer if strides are provided
|
| 327 |
+
if stride_m is not None and stride_n is not None:
|
| 328 |
+
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
|
| 329 |
+
|
| 330 |
+
# Handle all masking cases
|
| 331 |
+
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 332 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
|
| 333 |
+
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
|
| 334 |
+
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
|
| 335 |
+
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
|
| 336 |
+
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
|
| 337 |
+
else: # Both divisible
|
| 338 |
+
return tl.load(ptr)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Common Imports
|
| 342 |
+
@triton.jit
|
| 343 |
+
def forward_block_mn(
|
| 344 |
+
arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
|
| 345 |
+
q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
|
| 346 |
+
# accumulated values
|
| 347 |
+
acc, l_i, m_i,
|
| 348 |
+
# Offsets
|
| 349 |
+
off_z, off_h, offs_m, offs_n,
|
| 350 |
+
# Offsets needed for TMA loads
|
| 351 |
+
kv_start,
|
| 352 |
+
kv_offset,
|
| 353 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 354 |
+
# Strides for K and V
|
| 355 |
+
stride_kk, stride_kn, stride_vn, stride_vk,
|
| 356 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
| 357 |
+
|
| 358 |
+
):
|
| 359 |
+
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
| 360 |
+
PRESCALE_QK : tl.constexpr = False
|
| 361 |
+
ROWS_GUARANTEED_SAFE : tl.constexpr = False
|
| 362 |
+
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
|
| 363 |
+
WRITE_DQ : tl.constexpr = True
|
| 364 |
+
OUTPUT_LOGSUMEXP : tl.constexpr = True
|
| 365 |
+
OUTPUT_MAX : tl.constexpr = False
|
| 366 |
+
FLOAT32_PRECISION : tl.constexpr = 'tf32'
|
| 367 |
+
IS_DIVISIBLE : tl.constexpr = False
|
| 368 |
+
GQA_SHARED_HEADS : tl.constexpr = 4
|
| 369 |
+
HAS_FULL_BLOCKS : tl.constexpr = True
|
| 370 |
+
SM_SCALE : tl.constexpr = 0.08838834764831845
|
| 371 |
+
SPLIT_KV : tl.constexpr = 32
|
| 372 |
+
QK_HEAD_DIM : tl.constexpr = 128
|
| 373 |
+
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 374 |
+
V_HEAD_DIM : tl.constexpr = 128
|
| 375 |
+
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
|
| 376 |
+
SAFE_HEAD_DIM : tl.constexpr = True
|
| 377 |
+
BLOCK_M : tl.constexpr = 512
|
| 378 |
+
SAFE_M_BOUNDARY : tl.constexpr = False
|
| 379 |
+
SAFE_N_BOUNDARY : tl.constexpr = True
|
| 380 |
+
BLOCK_N : tl.constexpr = 64
|
| 381 |
+
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
|
| 382 |
+
USE_TMA : tl.constexpr = False
|
| 383 |
+
INDEX_DTYPE : tl.constexpr = tl.int32
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# -- load k --
|
| 387 |
+
    # NB: reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
    # we need to mask out the elements that are beyond Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are beyond KV_LEN for non-divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: the l_i update is pulled up here since it's a bit faster.
    # NB: for headdim=256 it's faster to move it back down, after m_i = m_ij.
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod,
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831845
    SPLIT_KV : tl.constexpr = 32
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M : tl.constexpr = 512
    SAFE_M_BOUNDARY : tl.constexpr = False
    SAFE_N_BOUNDARY : tl.constexpr = True
    BLOCK_N : tl.constexpr = 64
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    USE_TMA : tl.constexpr = False
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
            # it's on par with or slightly faster than only applying them to the last block in fwd.
            # However, we choose a different strategy for bwd, where we only apply mod & mask
            # to the last block because it's much faster there.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i
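The forward_block_mn / forward_inner pair above implements the streaming ("online") softmax used by flash-style attention: each KV tile updates a running row maximum m_i, a running denominator l_i, and a running weighted-value accumulator acc, rescaling the old state by alpha = exp(m_old - m_new) so nothing overflows. Below is a minimal NumPy sketch of that update, for reference only; it is not part of the cached kernel, uses natural exp instead of the kernel's exp2/RCP_LN2 change of base, and omits score_mod/mask_mod, block sparsity, and masked-row handling.

import numpy as np

def online_softmax_attention(q, k_tiles, v_tiles, sm_scale):
    # q: [M, D]; k_tiles / v_tiles: lists of [N, D] / [N, Dv] blocks.
    m_i = np.full(q.shape[0], -np.inf)                 # running row maxima
    l_i = np.zeros(q.shape[0])                         # running softmax denominators
    acc = np.zeros((q.shape[0], v_tiles[0].shape[1]))  # running weighted sum of V
    for k, v in zip(k_tiles, v_tiles):
        qk = (q @ k.T) * sm_scale                      # scores for this tile
        m_ij = np.maximum(m_i, qk.max(axis=1))         # new running maxima
        p = np.exp(qk - m_ij[:, None])                 # unnormalized probabilities
        alpha = np.exp(m_i - m_ij)                     # rescale factor for the old state
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v
        m_i = m_ij
    return acc / l_i[:, None]                          # mirrors acc / l_i[:, None] in the kernel epilogue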
progress/SpecForge/cache/compiled_kernels/a3/ca3omjumwqpxxjrgphxuxva3yanssfkbnvrp3buqomyudb2eg4nc.py
ADDED
@@ -0,0 +1,534 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
)
@triton.jit
def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined on the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but is slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1

    ZQ = 1
    HQ = 32
    Q_LEN = ks0
    ZKV = 1
    KV_LEN = ks1

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for the batch dimension: a) (ZKV == ZQ), where off_zkv = off_zq,
    # and b) (ZKV == 1 and ZQ > 1), where KV is broadcast along the batch dimension and off_zkv = 0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    # Setting up the TMA descriptors for Q, K, V
    desc_q = None
    desc_k = None
    desc_v = None

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    stride_kv_num_blks_h = 1
    stride_kv_idx_h = 1
    stride_kv_idx_m = 1

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to them
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    # K and V pointers will be passed directly to forward_inner
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )

    # [Note] Handle fully masked out rows:
    # l_i will be sum(e^(-inf)) == 0.0 for masked out rows, and m_i will be -inf.
    # We set l_i to 1.0, which results in lse/out = 0.0 after the log(l_i) + m_i (0.0) step.
    l_i = tl.where(l_i == 0.0, 1, l_i)

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

    if OUTPUT_MAX:
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)


# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate the final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)


# Common Imports
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # -- load k --
    # NB: reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
    # we need to mask out the elements that are beyond Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    tmp0 = (qk)
    post_mod_scores = tmp0

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are beyond KV_LEN for non-divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp1 = (m)
        tmp2 = tl.full([1], 0, tl.int32)
        tmp3 = tmp1 < tmp2
        tmp4 = (n)
        tmp5 = tmp4 <= tmp1
        tmp6 = tmp3 & tmp5
        tmp7 = tmp1 >= tmp2
        tmp8 = tmp4 < tmp2
        tmp9 = tmp7 & tmp8
        tmp10 = tmp8 == 0
        tmp11 = tmp7 & tmp10
        tmp12 = tmp1 - tmp2
        tmp13 = tl.full([1], 16, tl.int32)
        tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13)
        tmp15 = tmp4 - tmp2
        tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13)
        tmp17 = tmp14 == tmp16
        tmp18 = tmp11 & tmp17
        tmp19 = tmp9 | tmp18
        tmp20 = tmp6 | tmp19
        mask_mod_output = tmp20

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: the l_i update is pulled up here since it's a bit faster.
    # NB: for headdim=256 it's faster to move it back down, after m_i = m_ij.
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod,
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831845
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
            # it's on par with or slightly faster than only applying them to the last block in fwd.
            # However, we choose a different strategy for bwd, where we only apply mod & mask
            # to the last block because it's much faster there.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i
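For the block-sparse iteration above, get_offset_for_next_block advances the KV position by BLOCK_N while still inside the current SPARSE_KV_BLOCK_SIZE block and, on the last BLOCK_N tile of that block, jumps straight to the start of the next block listed in kv_indices. A plain-Python sketch of the same arithmetic follows, for reference only; the function name and the example indices are illustrative and are not part of the cached kernels.

def kv_tile_positions(col_indices, num_blocks, sparse_block=128, block=64):
    # Yield the starting KV position of every BLOCK_N-sized tile the inner
    # loop visits, starting from col_indices[0] * sparse_block.
    multiple = sparse_block // block                   # BLOCK_N tiles per sparse block
    pos = col_indices[0] * sparse_block
    for loop_iter in range(num_blocks * multiple):
        yield pos
        cur_idx = loop_iter // multiple
        cur_block = col_indices[cur_idx]
        next_block = col_indices[cur_idx + 1] if cur_idx + 1 < num_blocks else cur_block
        if (loop_iter + 1) % multiple == 0:            # finished this sparse block: jump
            pos += (next_block - cur_block) * sparse_block - (multiple - 1) * block
        else:                                          # still inside it: step one tile
            pos += block

# e.g. list(kv_tile_positions([0, 3], num_blocks=2)) == [0, 64, 384, 448]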
progress/SpecForge/cache/compiled_kernels/ad/4a5ce6c582fc1ef37d4cf3003d603da533264d4c59e6e0cb171d0b7490f32260.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
progress/SpecForge/cache/compiled_kernels/ad/b7e210dbfa93430d766b2fdeddfba773a52faf0eca21a68632a8853e9c3ecaf4.best_config
ADDED
@@ -0,0 +1 @@
{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"}
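Each *.best_config entry records the autotuning result that will be reused on later runs (block size, warp and stage counts, plus hashes tying it to the generated kernel). A small sketch for inspecting one of these files, assuming it is read relative to the repository root:

import json
from pathlib import Path

cfg_path = Path("progress/SpecForge/cache/compiled_kernels/ad/"
                "4a5ce6c582fc1ef37d4cf3003d603da533264d4c59e6e0cb171d0b7490f32260.best_config")
cfg = json.loads(cfg_path.read_text())
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])  # 128 4 1
print(cfg["found_by_coordesc"], cfg["time_taken_ms"])      # False 11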