diff --git a/SpecForge-ext/cache/compiled_kernels/27/178fbe80a6655cd928415af01c18f906aac88a91f6ad870b6e019e505e41d8d6.best_config b/SpecForge-ext/cache/compiled_kernels/27/178fbe80a6655cd928415af01c18f906aac88a91f6ad870b6e019e505e41d8d6.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7275f384fdb6ca9cbceb6da8e9f0f1fc1a30db44 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/27/178fbe80a6655cd928415af01c18f906aac88a91f6ad870b6e019e505e41d8d6.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 140, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/27/9e94f85bd7b5310932ad7debf393fd211c0f4a83a0599bb42899fad47199226a.best_config b/SpecForge-ext/cache/compiled_kernels/27/9e94f85bd7b5310932ad7debf393fd211c0f4a83a0599bb42899fad47199226a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6c3a3559496e6e4d68292da2e678eca0b03342ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/27/9e94f85bd7b5310932ad7debf393fd211c0f4a83a0599bb42899fad47199226a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py b/SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py new file mode 100644 index 0000000000000000000000000000000000000000..be58b17bc42bc67cee40cd406576e724b7bb651b --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, 
ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/27/c27s4qoyzyvf54snkgtay3lqlnoj3bgphotvv5xwczxe6bqovure.py b/SpecForge-ext/cache/compiled_kernels/27/c27s4qoyzyvf54snkgtay3lqlnoj3bgphotvv5xwczxe6bqovure.py new file mode 100644 index 
0000000000000000000000000000000000000000..46a4728c39d59fddd9a3e92563e4f8ad60738cdf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/27/c27s4qoyzyvf54snkgtay3lqlnoj3bgphotvv5xwczxe6bqovure.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/3j/319ae573ab7247866e4ff0749ebbc205378f8b72492d64376005ae12fdcc85cb.best_config b/SpecForge-ext/cache/compiled_kernels/3j/319ae573ab7247866e4ff0749ebbc205378f8b72492d64376005ae12fdcc85cb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c2d9b36c5180887fa413aa1eb230c04dc216dd00 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3j/319ae573ab7247866e4ff0749ebbc205378f8b72492d64376005ae12fdcc85cb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", 
"found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/3j/c3j47dekusw3y4mohtk5v36cc6fso3wdtqn5oqjwew3yy3exjo76.py b/SpecForge-ext/cache/compiled_kernels/3j/c3j47dekusw3y4mohtk5v36cc6fso3wdtqn5oqjwew3yy3exjo76.py new file mode 100644 index 0000000000000000000000000000000000000000..82b9d505064e4e025abfade8f2c296e5c51f0b2d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3j/c3j47dekusw3y4mohtk5v36cc6fso3wdtqn5oqjwew3yy3exjo76.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 
'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/3r/c3rfwo25yzbkbl5er7svhpxhxxjdbre5zoxeq5wwcwlsvq2puinx.py b/SpecForge-ext/cache/compiled_kernels/3r/c3rfwo25yzbkbl5er7svhpxhxxjdbre5zoxeq5wwcwlsvq2puinx.py new file mode 100644 index 
0000000000000000000000000000000000000000..ddfe54d8436e090c631bb95d25c88b3baed85fe7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3r/c3rfwo25yzbkbl5er7svhpxhxxjdbre5zoxeq5wwcwlsvq2puinx.py @@ -0,0 +1,322 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/df/cdfb6cgenzsju5cqvy4244xh4xidniyeznvkubvdg2mg6d5oc6xt.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, 
aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs 
= {}) +# %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 
1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': 
True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = 
tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gg/cgggk6pegregqt4lolln3yxfp6wzahy6vf2ocae3vbpohfif7mtz.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, 
s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 
- ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, 
ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, 
eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = 
tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9 + stream7 = get_raw_stream(7) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream7) + del 
tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream7 = get_raw_stream(7) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream7) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 2 + primals_11 = 32 + primals_13 = 8 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:7', dtype=torch.int64) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, 
primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/3z/c3zdyaemmekwelpoed5jduslt3o4gp6avp6it2wx3udu2z3kxz65.py b/SpecForge-ext/cache/compiled_kernels/3z/c3zdyaemmekwelpoed5jduslt3o4gp6avp6it2wx3udu2z3kxz65.py new file mode 100644 index 0000000000000000000000000000000000000000..c50dd3063b70a2e84b1d32e70e7bfa3e40b7eb12 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3z/c3zdyaemmekwelpoed5jduslt3o4gp6avp6it2wx3udu2z3kxz65.py @@ -0,0 +1,309 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia 
+reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rv/crvxbztkr372fgdn7bgrud22s3wmd2isidwo4ek4hldn56tuv2dj.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), 
kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, 
s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kt/ckt6ylj5dotkksawawvin5yyeytmo5tcvmqpulhfstvqh3aecfft.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_14 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:1" = PlaceHolder[target=primals_14] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, 
s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_14, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s25*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 
+triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + 
new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_14, (s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream1 = get_raw_stream(1) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream1) + del primals_12 + ps1 = s24*s25 + buf1 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, 
aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s25*s48*s9 + stream1 = get_raw_stream(1) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_14, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream1) + del primals_14 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s25, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:1', dtype=torch.int64) + primals_9 = 1 + primals_10 = 2 + primals_11 = 32 + primals_12 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_13 = 8 + primals_14 = rand_strided((2, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git 
a/SpecForge-ext/cache/compiled_kernels/43/c43xt3pebupiz26noaypbobqj4gw2z5njubsnsy3la7enx2j3exz.py b/SpecForge-ext/cache/compiled_kernels/43/c43xt3pebupiz26noaypbobqj4gw2z5njubsnsy3la7enx2j3exz.py new file mode 100644 index 0000000000000000000000000000000000000000..e807648d829792ee8dbe2f7718824f5b94750b77 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/43/c43xt3pebupiz26noaypbobqj4gw2z5njubsnsy3la7enx2j3exz.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:4" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:4" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 
1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = 
async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, 
ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xg/cxgourummzwsux6r2gxe7ifvqpdhpgvgbs36tkitfwpr24b4gcvt.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: 
[aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:4" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:4" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, 
[%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import 
libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), 
xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 
= primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream4 = get_raw_stream(4) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream4) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream4 = get_raw_stream(4) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream4) + del primals_13 + return (buf0, buf1, 
primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:4', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/46/7bb099c5cd896693d7710d6358df16b56a62f466c1207186e1ec6fa6aaeb5653.best_config b/SpecForge-ext/cache/compiled_kernels/46/7bb099c5cd896693d7710d6358df16b56a62f466c1207186e1ec6fa6aaeb5653.best_config new file mode 100644 index 0000000000000000000000000000000000000000..1e5a79796def180206ef96ecac567ed1428c9073 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/46/7bb099c5cd896693d7710d6358df16b56a62f466c1207186e1ec6fa6aaeb5653.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, 
"num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 63, "triton_cache_hash": "C3FCZCDEMCLSFODWXLEH5MRAQRWLOTRP4SAQURVAE7BPHZSTV2WQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py b/SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb5e9617f98767679c6e95ed04e524c1d01db64 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 
1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & 
xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/5d/c5dxklofifxzswxjhdjvko4ncyrk6vkfrbohhy3eg5kffm63zqjg.py b/SpecForge-ext/cache/compiled_kernels/5d/c5dxklofifxzswxjhdjvko4ncyrk6vkfrbohhy3eg5kffm63zqjg.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7f668f80d92731bb3a4d228dddbf121f545f2a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5d/c5dxklofifxzswxjhdjvko4ncyrk6vkfrbohhy3eg5kffm63zqjg.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, 
_tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) diff --git a/SpecForge-ext/cache/compiled_kernels/5i/c5ijpx5gd3tgnruo2ufacghy5ivgjwoj6s4fhr7c2advvuhujqou.py b/SpecForge-ext/cache/compiled_kernels/5i/c5ijpx5gd3tgnruo2ufacghy5ivgjwoj6s4fhr7c2advvuhujqou.py new file mode 100644 index 0000000000000000000000000000000000000000..8b68a43e7d5904fc5b4269bb17008365866b33d8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5i/c5ijpx5gd3tgnruo2ufacghy5ivgjwoj6s4fhr7c2advvuhujqou.py @@ -0,0 +1,334 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia 
+reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5a/c5aj4zurklc4jhvqpnfghi3vz6khjwkqqeoe4icjpmybdwcujxn6.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %buf0 +triton_red_fused__to_copy_mul_sum_0 = 
async_compile.triton('triton_red_fused__to_copy_mul_sum_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = 
r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ug/cug6unv7ylx7cgtwxj6q5dppff2io2k4qf3fhtoe6a2mcfi5dzu5.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:1" = 
PlaceHolder[target=buf0] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %sum_1 +triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): 
[['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sg/csg2r6gcuw5453tnkx7v65zysasesetlrx733ekbslnhgjntjrkm.py +# Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %primals_7 : Tensor 
"bf16[s33][1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1" = PlaceHolder[target=rsqrt] +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:1" = PlaceHolder[target=sum_2] +# %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {}) +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {}) +# %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {}) +# %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {}) +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), kwargs = {}) +# %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {}) +# %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {}) +# %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {}) +# %expand : Tensor "f32[s47, s87, s33][s87, 1, 
0]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {}) +# %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {}) +# %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {}) +# %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {}) +# %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {}) +# %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {}) +# %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 
'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = 
tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 = tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, 
s33), (s33*s87, s33, 1)) + assert_size_stride(primals_7, (s33, ), (1, )) + assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1)) + assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33 + triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream1) + buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + stream1 = get_raw_stream(1) + triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream1) + del buf0 + buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel = s47*s87 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream1) + del primals_4 + del primals_7 + del rsqrt + del tangents_1 + return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, 
repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_6 = 840433664 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = rand_strided((4096, ), (1, ), device='cuda:1', dtype=torch.bfloat16) + rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/5n/c5n5tizmmgs4cmiupzpopubn6t7eviwt42e3csvt472h63vjwmbu.py b/SpecForge-ext/cache/compiled_kernels/5n/c5n5tizmmgs4cmiupzpopubn6t7eviwt42e3csvt472h63vjwmbu.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7132610b2ca9404496fc39936fe32b78bd2a92 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5n/c5n5tizmmgs4cmiupzpopubn6t7eviwt42e3csvt472h63vjwmbu.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 
'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 
'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, 
DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. 
perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + 
stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/6a/c6akpy3glququp6suktd5kfns5jol46fmjfl5brlavs7c4zodqhi.py b/SpecForge-ext/cache/compiled_kernels/6a/c6akpy3glququp6suktd5kfns5jol46fmjfl5brlavs7c4zodqhi.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9fe273f730dfa7475773ff810bb2acbe01b62c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6a/c6akpy3glququp6suktd5kfns5jol46fmjfl5brlavs7c4zodqhi.py @@ -0,0 +1,354 @@ +# AOT ID: ['15_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from 
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/x7/cx7fsejzde6zv22nl7w3xpjhybajijgeetsfqi733ibymkptkdrq.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, 
TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + 
r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg4_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:3" = PlaceHolder[target=arg4_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg4_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = 
triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2o/c2otr5mtbf3tmh4ztmfjn6qv6r3raha22m4sr5h4kaplsk53xtg4.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:3" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:3" = PlaceHolder[target=argmax_1] +# %arg5_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:3" = PlaceHolder[target=arg5_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg5_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': 
'*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + 
(r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cj/ccjx73kwqy3z57a3fjxor5ma5tgytixf7htmrtqxzyfleohcklv4.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg7_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:3" = PlaceHolder[target=arg7_1] +# %sum_1 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg7_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': 
'*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 
1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s0 = arg3_1 + s14 = arg6_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg5_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg7_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream3 = get_raw_stream(3) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream3) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream3 = get_raw_stream(3) + triton_red_fused_argmax_1.run(arg4_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream3) + del arg4_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + 
stream3 = get_raw_stream(3) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg5_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream3) + del arg5_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream3 = get_raw_stream(3) + triton_red_fused_clamp_min_div_sum_3.run(arg7_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream3) + del arg7_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2014 + arg1_1 = rand_strided((2, 2014, 32000), (64448000, 32000, 1), device='cuda:3', dtype=torch.bfloat16) + arg2_1 = 64672000 + arg3_1 = 32000 + arg4_1 = rand_strided((2, 2014, 32000), (64672000, 32000, 1), device='cuda:3', dtype=torch.float32) + arg5_1 = rand_strided((2, 2014, 1), (2014, 1, 1), device='cuda:3', dtype=torch.int64) + arg6_1 = 2014 + arg7_1 = rand_strided((2, 2014, 1), (2014, 1, 1), device='cuda:3', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py b/SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py new file mode 100644 index 0000000000000000000000000000000000000000..1c96bc3387e8d6c012bfd3e75fcab00a5ce7ce7a --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/73/c73fwyrbq77x2xei6vgy66r34rmoo6pilk6kf2iqehy6oksjmbwj.py b/SpecForge-ext/cache/compiled_kernels/73/c73fwyrbq77x2xei6vgy66r34rmoo6pilk6kf2iqehy6oksjmbwj.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5cd79435cc666569d37c169d830fd04a91e608 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/73/c73fwyrbq77x2xei6vgy66r34rmoo6pilk6kf2iqehy6oksjmbwj.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = 
torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hj/chjn2h2lagxtgealz3aitqmfnksszmnt7q4hnsw5vu6risac6dmq.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import 
triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 
'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. 
+ # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, 
primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream3) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call 
+recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', 
benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/bj/cbj72m23cmcn2yjoxrp4vabc2f76gw727jcpbi4y5oidokqenki5.py b/SpecForge-ext/cache/compiled_kernels/bj/cbj72m23cmcn2yjoxrp4vabc2f76gw727jcpbi4y5oidokqenki5.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8221323262298eda8334c068bb3496859dab71 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bj/cbj72m23cmcn2yjoxrp4vabc2f76gw727jcpbi4y5oidokqenki5.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, 
xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ce/ccez434a7hzyympuosxgkqmu5zncaqowipase2enwlehm3k7igny.py b/SpecForge-ext/cache/compiled_kernels/ce/ccez434a7hzyympuosxgkqmu5zncaqowipase2enwlehm3k7igny.py new file mode 100644 index 0000000000000000000000000000000000000000..3decc360cdf90ccfce2af3492138b3bbe74f33a5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ce/ccez434a7hzyympuosxgkqmu5zncaqowipase2enwlehm3k7igny.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + 
tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/cm/c07273381756209821a51449b0970c31551d44d464333fbb852c9fc655362c46.best_config b/SpecForge-ext/cache/compiled_kernels/cm/c07273381756209821a51449b0970c31551d44d464333fbb852c9fc655362c46.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a337a719c6503c8dcbad0c427c4a5067600d0bd0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/cm/c07273381756209821a51449b0970c31551d44d464333fbb852c9fc655362c46.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/cr/ccr5s7nffy4cqd7a3lcq3cnv2prruzwzc7chchf776jguuqqh5bc.py b/SpecForge-ext/cache/compiled_kernels/cr/ccr5s7nffy4cqd7a3lcq3cnv2prruzwzc7chchf776jguuqqh5bc.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5fdd87f7d7188318428b3a19c60f300b0c7cbf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/cr/ccr5s7nffy4cqd7a3lcq3cnv2prruzwzc7chchf776jguuqqh5bc.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties 
+triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = 
ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, 
eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/d4/cd4qgy6v3vbg74qytdbsmdpamjzb6kuwcsiu7yfpml4f7zxhf4j3.py b/SpecForge-ext/cache/compiled_kernels/d4/cd4qgy6v3vbg74qytdbsmdpamjzb6kuwcsiu7yfpml4f7zxhf4j3.py new file mode 100644 index 0000000000000000000000000000000000000000..659e552c42e3626aedc6a383eb103373c7a98733 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/d4/cd4qgy6v3vbg74qytdbsmdpamjzb6kuwcsiu7yfpml4f7zxhf4j3.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = 
AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zw/czwh6tgkq6scdstgzueb3goqqnllndikoasj2i2iehu2qyvoccwt.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 151936][311164928, 151936, 1]cuda:6" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: 
tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + 
assert_size_stride(arg0_1, (8, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (8, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 16384, 151936, stream=stream6) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 151936), (311164928, 151936, 1), device='cuda:6', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:6', dtype=torch.bool) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/e4/9676ab333d44c7c3eec122b806e4fc2028468bcd55a94115e7272e322515d58c.best_config b/SpecForge-ext/cache/compiled_kernels/e4/9676ab333d44c7c3eec122b806e4fc2028468bcd55a94115e7272e322515d58c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cd9795263343a19ee8f06cf527807cd2d9adfee5 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/e4/9676ab333d44c7c3eec122b806e4fc2028468bcd55a94115e7272e322515d58c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/e6/ce65awkeaxcxjqfa27pcogsy3sjyxwzxjt3w2rte76m7izgybp2s.py b/SpecForge-ext/cache/compiled_kernels/e6/ce65awkeaxcxjqfa27pcogsy3sjyxwzxjt3w2rte76m7izgybp2s.py new file mode 100644 index 0000000000000000000000000000000000000000..7c5370859a5813361a5f083facd661107c521e75 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e6/ce65awkeaxcxjqfa27pcogsy3sjyxwzxjt3w2rte76m7izgybp2s.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = 
tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ea/ceahttlkg35qey3ao6gw65rzzv3bop5xwrthhogt6nyvyw3rece5.py b/SpecForge-ext/cache/compiled_kernels/ea/ceahttlkg35qey3ao6gw65rzzv3bop5xwrthhogt6nyvyw3rece5.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a1bd5437c7a2acb5874ea91684a45ea486312a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ea/ceahttlkg35qey3ao6gw65rzzv3bop5xwrthhogt6nyvyw3rece5.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', 
index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def 
triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The 
dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/eg/32c642d9b91cf1b4fe91f745e14ee86eccc7f5783759ca318976f5af47c474c9.best_config b/SpecForge-ext/cache/compiled_kernels/eg/32c642d9b91cf1b4fe91f745e14ee86eccc7f5783759ca318976f5af47c474c9.best_config new file mode 100644 index 0000000000000000000000000000000000000000..128251849e0d90499e31f76727557122755609e2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eg/32c642d9b91cf1b4fe91f745e14ee86eccc7f5783759ca318976f5af47c474c9.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py b/SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7fd1b8d26426aa87bd952099437a0924110c07 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': 
False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= 
tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/eg/cegtc7d6fywtdtf2rerfuwwwn7fajohhh2ltqlvjjvdguxabbva4.py b/SpecForge-ext/cache/compiled_kernels/eg/cegtc7d6fywtdtf2rerfuwwwn7fajohhh2ltqlvjjvdguxabbva4.py new file mode 100644 index 0000000000000000000000000000000000000000..9a768c792501f9958b67c44a3d9d6885c07fb98e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eg/cegtc7d6fywtdtf2rerfuwwwn7fajohhh2ltqlvjjvdguxabbva4.py @@ -0,0 +1,416 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import 
_cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/p3/cp3lt4qtmnlmp6kb7cx5zc6bshlrxlbfed2c4ciyoiapxknraax3.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s3, 32000][32000*s3, 32000, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': 
'*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + 
_tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/eu/ceui6qrb2t3lmzs3ljrqtcomt4b2q6svzo24j6mmryaiovr6kp7y.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[8, s3, 32000][s71, 32000, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 
'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:0" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:0" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[8, s3, 1][s3, 1, 1]cuda:0" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', 
index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), 
r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xe/cxe74qazmcwxkyh3xlgupaetbeksmhlptcogpgxu7tfvr4arcob6.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg6_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 
'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/az/cazc4elakae7tgyuygha6gaxmfo4ouj4mtb6kxylbj7524jvkqaz.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : 
Tensor "i64[2][1]cuda:0" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_2] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 
'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = 
triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (8, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (8, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (8, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (8, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 8*s3 + stream0 = get_raw_stream(0) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream0) + del arg1_1 + buf1 = empty_strided_cuda((8, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 8*s3 + stream0 = get_raw_stream(0) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream0) + del arg3_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 4*s3 + stream0 = get_raw_stream(0) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf3, s3, 2, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream0) + del arg4_1 + del 
buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + triton_red_fused_sum_3_r0_numel = 4*s14 + stream0 = get_raw_stream(0) + triton_red_fused_sum_3.run(arg6_1, buf5, s14, 2, triton_red_fused_sum_3_r0_numel, stream=stream0) + del arg6_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream0 = get_raw_stream(0) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream0) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2009 + arg1_1 = rand_strided((8, 2009, 32000), (64288000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = 64512000 + arg3_1 = rand_strided((8, 2009, 32000), (64512000, 32000, 1), device='cuda:0', dtype=torch.float32) + arg4_1 = rand_strided((8, 2009, 1), (2009, 1, 1), device='cuda:0', dtype=torch.int64) + arg5_1 = 2009 + arg6_1 = rand_strided((8, 2009, 1), (2009, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/el/celb4xosanyf3m2sx6v3t54w4bgkx65m4lb2newu7nggikw6jbxj.py b/SpecForge-ext/cache/compiled_kernels/el/celb4xosanyf3m2sx6v3t54w4bgkx65m4lb2newu7nggikw6jbxj.py new file mode 100644 index 
0000000000000000000000000000000000000000..daee6f274fabe68ba53b1f715c54e717632418d6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/el/celb4xosanyf3m2sx6v3t54w4bgkx65m4lb2newu7nggikw6jbxj.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/et/cet6lrlwcthdi3by3ttnab2z245l4q55x7tvdilkic6xqjfjlixg.py b/SpecForge-ext/cache/compiled_kernels/et/cet6lrlwcthdi3by3ttnab2z245l4q55x7tvdilkic6xqjfjlixg.py new file mode 100644 index 0000000000000000000000000000000000000000..aea208c3bff3e3beedaf05c712511fce5e866084 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/et/cet6lrlwcthdi3by3ttnab2z245l4q55x7tvdilkic6xqjfjlixg.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def 
triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) diff --git a/SpecForge-ext/cache/compiled_kernels/ey/cey3ar6s7f2t62buescu5cctxdhf6hmbv3ps5d3tmh235oaj3fj6.py 
b/SpecForge-ext/cache/compiled_kernels/ey/cey3ar6s7f2t62buescu5cctxdhf6hmbv3ps5d3tmh235oaj3fj6.py new file mode 100644 index 0000000000000000000000000000000000000000..402044816e92b86067fa38d298c8d4def3c3f593 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ey/cey3ar6s7f2t62buescu5cctxdhf6hmbv3ps5d3tmh235oaj3fj6.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/ey/ceyifglcwq5k7zog6faauufd7zk5fsacgjqk43m6vpya73dy3l62.py b/SpecForge-ext/cache/compiled_kernels/ey/ceyifglcwq5k7zog6faauufd7zk5fsacgjqk43m6vpya73dy3l62.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc7bfc87939c785c4808eefb038ee5b3fd393e4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ey/ceyifglcwq5k7zog6faauufd7zk5fsacgjqk43m6vpya73dy3l62.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wr/cwrt3cdfiri2z4jso4afypedtru4cdebpo556yzgrqawlufswk26.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[8][1]cuda:6" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:6"[num_users=5] = 
call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = 
call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# 
%bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 
'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 
128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hw/chwm44jdqtovypwqknevqvz2d2xrazceb4ci2erooz4tahlocvzv.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math 
as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rk/crkxpwiwhzkvun7i5d2pegofthyfijn5wygwnaev3twwlrbuojqe.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, 
num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = 
PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 
16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = 
call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) 
+# %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_11 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor 
"i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = 
call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 
16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = 
tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4y/c4yxgotihoxpn6o5xa4jvkcy7shlgnyv44u6dpm5e746f6dwg7oe.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, 
num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:6" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:6" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 
1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 
16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, 
padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream6 = get_raw_stream(6) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream6) + del arg0_1 + buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream6) + buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream6) + buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + 
stream6 = get_raw_stream(6) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream6) + del buf0 + buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream6) + del buf8 + buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream6) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, ), (1, ), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git 
a/SpecForge-ext/cache/compiled_kernels/ey/ceyzf3pcewvjtqjk6jiokovxh2sqktcak7dttp7wu3pugjxaoweu.py b/SpecForge-ext/cache/compiled_kernels/ey/ceyzf3pcewvjtqjk6jiokovxh2sqktcak7dttp7wu3pugjxaoweu.py new file mode 100644 index 0000000000000000000000000000000000000000..0924187f7d547b3a4e646584cd4e0cfa5acac65b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ey/ceyzf3pcewvjtqjk6jiokovxh2sqktcak7dttp7wu3pugjxaoweu.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], 
(15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. 
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ey/f8e6f482f3185b2937177b6d0b6caa60104c3cdb0966b9b98cfda24132197a8c.best_config b/SpecForge-ext/cache/compiled_kernels/ey/f8e6f482f3185b2937177b6d0b6caa60104c3cdb0966b9b98cfda24132197a8c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a570e8d663ff6e600f50df05a811c859065ec3c4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ey/f8e6f482f3185b2937177b6d0b6caa60104c3cdb0966b9b98cfda24132197a8c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/fg/cfg7sytfzjcof3mvqa6lexwoxlaj3zogf2jn2jbgerew6ytuhqkm.py b/SpecForge-ext/cache/compiled_kernels/fg/cfg7sytfzjcof3mvqa6lexwoxlaj3zogf2jn2jbgerew6ytuhqkm.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c6a4865d8e5ccf241d2a427a1b8466ac23b175 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fg/cfg7sytfzjcof3mvqa6lexwoxlaj3zogf2jn2jbgerew6ytuhqkm.py @@ -0,0 +1,37 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 
'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/fg/cfgdj37atk5pvqz7oags4dv3jc65exjssmxxu3c4srgtfjnh7kgw.py b/SpecForge-ext/cache/compiled_kernels/fg/cfgdj37atk5pvqz7oags4dv3jc65exjssmxxu3c4srgtfjnh7kgw.py new file mode 100644 index 0000000000000000000000000000000000000000..acb2c2ead93c852ac0ca71b093bfda4d7889f254 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fg/cfgdj37atk5pvqz7oags4dv3jc65exjssmxxu3c4srgtfjnh7kgw.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, 
cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : 
tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/fg/cfgilsqr4dj7cpcripi7zlobhu3rqxlfddiwwrzuy5xlumnjw5lh.py b/SpecForge-ext/cache/compiled_kernels/fg/cfgilsqr4dj7cpcripi7zlobhu3rqxlfddiwwrzuy5xlumnjw5lh.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9f66dbd03cc52f13854bf6e942ff19f755b222 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/fg/cfgilsqr4dj7cpcripi7zlobhu3rqxlfddiwwrzuy5xlumnjw5lh.py @@ -0,0 +1,37 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + 
tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/fo/cfooe7ht55q5jhejzd3zyb3g5v64cvxjohkxeadllgnjxgiwo52v.py b/SpecForge-ext/cache/compiled_kernels/fo/cfooe7ht55q5jhejzd3zyb3g5v64cvxjohkxeadllgnjxgiwo52v.py new file mode 100644 index 0000000000000000000000000000000000000000..bad13ae7c037c509863f3b4539494c9b90622ccf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fo/cfooe7ht55q5jhejzd3zyb3g5v64cvxjohkxeadllgnjxgiwo52v.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ft/cftsee2mvtzxgy2wgchwunv4g4rgysco4n3gsokqlal6zoqbmnub.py b/SpecForge-ext/cache/compiled_kernels/ft/cftsee2mvtzxgy2wgchwunv4g4rgysco4n3gsokqlal6zoqbmnub.py new file mode 100644 index 0000000000000000000000000000000000000000..fb206e8277f332f67f9b92e162536a0e7ee53bec --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ft/cftsee2mvtzxgy2wgchwunv4g4rgysco4n3gsokqlal6zoqbmnub.py @@ -0,0 +1,303 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import 
_cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:1" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 
'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', 
other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gk/cgkqxvbgd6bawj2pp2icrhzkfuzcptxodfjpshgozv6ysjvxo65g.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[2, 2048, 32000][65760000, 32000, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 
16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = 
triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:1" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:1" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:1" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %sum_1 : Tensor "i64[][]cuda:1" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:1" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[2, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:1"[num_users=1] = 
call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 
'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) +''', device_str='cuda') + + 
+async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (2, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream1 = get_raw_stream(1) + triton_red_fused_argmax_0.run(arg0_1, buf0, 4096, 32000, stream=stream1) + del arg0_1 + buf1 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream1 = get_raw_stream(1) + triton_red_fused_argmax_1.run(arg1_1, buf1, 4096, 32000, stream=stream1) + del arg1_1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream1 = get_raw_stream(1) + triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, arg3_1, buf4, 1, 4096, stream=stream1) + del arg2_1 + del arg3_1 + del buf0 + del buf1 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), 
device='cuda:1', dtype=torch.bfloat16) + arg1_1 = rand_strided((2, 2048, 32000), (65760000, 32000, 1), device='cuda:1', dtype=torch.float32) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.int64) + arg3_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/g3/cg3kczutozttzr55b4vjq62nto7vv2qnqb553mhae4gtgepz7vkj.py b/SpecForge-ext/cache/compiled_kernels/g3/cg3kczutozttzr55b4vjq62nto7vv2qnqb553mhae4gtgepz7vkj.py new file mode 100644 index 0000000000000000000000000000000000000000..1c03ddd6c93947903332b09bc1ddfa98f16345b2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/g3/cg3kczutozttzr55b4vjq62nto7vv2qnqb553mhae4gtgepz7vkj.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): 
[['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 
= tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/hq/chqc3is7lze3bdohf7qrowyfetyhjquhgfsobrnoq7hbrmp6ohdx.py b/SpecForge-ext/cache/compiled_kernels/hq/chqc3is7lze3bdohf7qrowyfetyhjquhgfsobrnoq7hbrmp6ohdx.py new file mode 100644 index 0000000000000000000000000000000000000000..a3abf354f603142469d7042a1623a2deaea6459d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hq/chqc3is7lze3bdohf7qrowyfetyhjquhgfsobrnoq7hbrmp6ohdx.py @@ -0,0 +1,334 @@ +# AOT ID: ['2_backward'] +from ctypes import 
c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=tangents_1] +# 
%primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %buf0 +triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', 
index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + 
ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/yf/cyft7sialepriw6eujulaxpi57qlrafkmp4k2kjwzw4noh23ddz6.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:6" = PlaceHolder[target=buf0] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor 
"bf16[1, 1, s33][s33, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %sum_1 +triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + 
R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gn/cgnmjxikvi5ulcyj3uozif3le5hd26kw2kjhkcbhupqgudqi3bwn.py +# Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %primals_7 : Tensor "bf16[s33][1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6" = PlaceHolder[target=rsqrt] +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:6" = PlaceHolder[target=sum_2] +# %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {}) +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=2] = 
call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {}) +# %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {}) +# %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {}) +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), kwargs = {}) +# %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {}) +# %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {}) +# %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {}) +# %expand : Tensor "f32[s47, s87, s33][s87, 1, 0]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {}) +# %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {}) +# %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {}) +# %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {}) +# %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {}) 
+# %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {}) +# %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 
'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', 
other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 = tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1)) + assert_size_stride(primals_7, (s33, ), (1, )) + assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1)) + assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33 + triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, 
triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream6) + buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream6) + del buf0 + buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel = s47*s87 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream6) + del primals_4 + del primals_7 + del rsqrt + del tangents_1 + return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_6 = 840433664 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((4096, ), (1, ), device='cuda:6', dtype=torch.bfloat16) + rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + 
from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/hq/chqstdcrwlggtj2cbkjjgtxib54f5qfcipeqs3k27hifudgguv7t.py b/SpecForge-ext/cache/compiled_kernels/hq/chqstdcrwlggtj2cbkjjgtxib54f5qfcipeqs3k27hifudgguv7t.py new file mode 100644 index 0000000000000000000000000000000000000000..2b503c91e70305775f53710c2b30f72181b202e8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hq/chqstdcrwlggtj2cbkjjgtxib54f5qfcipeqs3k27hifudgguv7t.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 
16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : 
tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. 
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/hv/chvj5h3adlnuxifatrhlirixthstwv5pzbxvuapjby5cz2npck63.py b/SpecForge-ext/cache/compiled_kernels/hv/chvj5h3adlnuxifatrhlirixthstwv5pzbxvuapjby5cz2npck63.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed764641b9c0616cb1a9fc9103086cf9daca8c5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hv/chvj5h3adlnuxifatrhlirixthstwv5pzbxvuapjby5cz2npck63.py @@ -0,0 +1,99 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, 
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + 
x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = 
tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ks/36cfdc5c4318d8e35940f3471fa9a8cde8092c3294a90679819920b4db6ea3bb.best_config b/SpecForge-ext/cache/compiled_kernels/ks/36cfdc5c4318d8e35940f3471fa9a8cde8092c3294a90679819920b4db6ea3bb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7d56ea7451f6ff3ceffec392bc015b86ab20533e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ks/36cfdc5c4318d8e35940f3471fa9a8cde8092c3294a90679819920b4db6ea3bb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "4UWYNBR3KPWQGNAZ5LIIRE7YAZWTQP4CP3JS6GOSLWYDF5K7WTAA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ks/cksdatp7sjl5kfr5pxvwrbjelhvz35c35rvym5wgbvhrovwd5isa.py b/SpecForge-ext/cache/compiled_kernels/ks/cksdatp7sjl5kfr5pxvwrbjelhvz35c35rvym5wgbvhrovwd5isa.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0edd4a99f79689276c94f4b418f4f12516f6bc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ks/cksdatp7sjl5kfr5pxvwrbjelhvz35c35rvym5wgbvhrovwd5isa.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + 
filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, 
eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ks/ckske6cm4vgoewu6hpzmhdk7yxnddtnqlrbts7nwodsrty3grim2.py b/SpecForge-ext/cache/compiled_kernels/ks/ckske6cm4vgoewu6hpzmhdk7yxnddtnqlrbts7nwodsrty3grim2.py new file mode 100644 index 0000000000000000000000000000000000000000..9d030cafdd30060777b962197576e87d600ff79c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ks/ckske6cm4vgoewu6hpzmhdk7yxnddtnqlrbts7nwodsrty3grim2.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/kx/ckxgrh6l45wgzd3gv6uy3i3z4hrfyct6es6sh2fdnsi6q4hicyjs.py b/SpecForge-ext/cache/compiled_kernels/kx/ckxgrh6l45wgzd3gv6uy3i3z4hrfyct6es6sh2fdnsi6q4hicyjs.py new file mode 100644 index 0000000000000000000000000000000000000000..1e81bccf481daf63cd20c79aeb12b62d42854c92 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kx/ckxgrh6l45wgzd3gv6uy3i3z4hrfyct6es6sh2fdnsi6q4hicyjs.py @@ -0,0 +1,168 @@ +# AOT ID: ['10_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile 
+from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/py/cpyon4zgaupgqfwtaeshxummq5taahi4k54ubix2xgrrupxyugiq.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul_2 +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:5" = PlaceHolder[target=argmax] +# %arg2_1 : Tensor "b8[151936][1]cuda:5" = PlaceHolder[target=arg2_1] +# 
%arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:5" = PlaceHolder[target=arg3_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# %index : Tensor "b8[2, s14][s14, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {}) +# return %argmax,%mul_2 +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 
16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = 
triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s24 = arg0_1 + arg1_1_size = arg1_1.size() + s14 = arg1_1_size[1] + assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1)) + assert_size_stride(arg2_1, (151936, ), (1, )) + assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream5) + del arg1_1 + del arg2_1 + del arg3_1 + return (buf1, ) 
+ +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1569 + arg1_1 = rand_strided((2, 1569, 151936), (238387584, 151936, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = rand_strided((151936, ), (1, ), device='cuda:5', dtype=torch.bool) + arg3_1 = rand_strided((2, 1569, 1), (1569, 1, 1), device='cuda:5', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/kx/ckxtdzhg3azhdxeooy2uushwzka4sz2hzjpq5dulk2g2jjweqr6b.py b/SpecForge-ext/cache/compiled_kernels/kx/ckxtdzhg3azhdxeooy2uushwzka4sz2hzjpq5dulk2g2jjweqr6b.py new file mode 100644 index 0000000000000000000000000000000000000000..c521668981a6811a0b0a00047a6aa3a8f493725a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kx/ckxtdzhg3azhdxeooy2uushwzka4sz2hzjpq5dulk2g2jjweqr6b.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 
'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : 
tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/l4/858a6c2e50b765fa4386efe0007977eb588741281d2c492d383f481ceaa46b11.best_config b/SpecForge-ext/cache/compiled_kernels/l4/858a6c2e50b765fa4386efe0007977eb588741281d2c492d383f481ceaa46b11.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cd9795263343a19ee8f06cf527807cd2d9adfee5 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/l4/858a6c2e50b765fa4386efe0007977eb588741281d2c492d383f481ceaa46b11.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/l4/cl45ilp34erze7maypgnzjiaafh3lmzk67erw2irtjg7fhwhyggv.py b/SpecForge-ext/cache/compiled_kernels/l4/cl45ilp34erze7maypgnzjiaafh3lmzk67erw2irtjg7fhwhyggv.py new file mode 100644 index 0000000000000000000000000000000000000000..b5dc897e7f23ab36d91c6365d930443140fa47f7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/l4/cl45ilp34erze7maypgnzjiaafh3lmzk67erw2irtjg7fhwhyggv.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, 
in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads 
sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulate dK / dV for one KV tile by walking over Q tiles.
    #
    # Each iteration hands the current Q^T / dO pointer tiles to
    # bwd_dkdv_block_mn, which returns the updated (dk, dv), and then advances
    # the pointers and offs_m1 by the step that get_offset_for_next_block
    # reports for the block-sparse layout.
    #
    # Returns the accumulated (dk, dv) pair.

    # ---- kernel configuration (repeated in every helper of this template) ----
    # numerics / behaviour switches
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    # head dimensions
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    # tile sizes
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # The tiling below only works when BLOCK_N1 is a multiple of BLOCK_M1.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    # Sequence lengths for this specialization: Q_LEN is a literal, KV_LEN is
    # the runtime value ks0.  These are deliberately plain assignments.
    Q_LEN = 2048
    KV_LEN = ks0

    # Offsets along the head dimension for the Q^T and dO tiles.
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head dim on rows), dO in its natural layout.
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd

    # Clamp the trip count to the number of BLOCK_M1 tiles in Q_LEN (at least
    # one): needed when running with a very large sparse block size, i.e. no
    # block mask.
    tiles_in_seq = tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)
    num_q_tiles = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tiles_in_seq)

    for tile_idx in range(0, num_q_tiles):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Step (in rows of Q) to the tile handled by the next iteration.
        advance = get_offset_for_next_block(
            tile_idx, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        offs_m1 = offs_m1 + advance
        qT_ptrs = qT_ptrs + advance * stride_qm
        do_ptrs = do_ptrs + advance * stride_dom

    return dk, dv
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Return how far (in elements along the sequence axis) the caller must
    # advance to reach the tile processed by the next loop iteration.
    #
    # With contiguous blocks the step is always BLOCK.  Otherwise the step is
    # BLOCK while staying inside the current sparse block, and a jump computed
    # from two consecutive entries of col_indices when the iteration is the
    # last one of its sparse block.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    sparse_slot = loop_iter // SPARSE_BLOCK_MULTIPLE
    # True on the last BLOCK-sized tile of the current sparse block.
    at_sparse_boundary = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # The following entry is only read when it lies inside col_indices.
    has_successor = sparse_slot + 1 < total_blocks
    this_block = tl.load(col_indices + sparse_slot, eviction_policy="evict_last")
    following_block = tl.load(col_indices + sparse_slot + 1, eviction_policy="evict_last", mask=has_successor)
    # Distance between the two sparse blocks, minus the tiles already walked
    # inside the current one.
    jump_stride = (following_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump_stride * at_sparse_boundary + (1 - at_sparse_boundary) * BLOCK

@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wrap `indices` into [0, max_len) via modulo when a bound is supplied;
    # pass them through untouched when max_len is None.
    if max_len is not None:
        return indices % max_len
    else:
        return indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load from a block pointer, boundary-checking (with zero padding) only
    # the axes that are not known to be safe: axis 0 when not IS_DIVISIBLE,
    # axis 1 when not SAFE_HEAD_DIM.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Load a 2D tile, masking (and zero-filling) each axis that is not flagged
    # as divisible.  When both strides are given, `ptr` is a base pointer and
    # the tile addresses are built here; otherwise `ptr` is used as passed.
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Both axes in bounds: plain load.
            return tl.load(ptr)
        else:
            # Only the column axis can run past N_LEN.
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            # Only the row axis can run past M_LEN.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            # Either axis can run out of bounds.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index 
# Attention forward template kernel.
#
# Each program handles one BLOCK_M tile of query rows for one (batch, head)
# pair: program_id(0) selects the query tile, program_id(1) the batch index and
# program_id(2) the query head.  It loads the Q tile once, runs forward_inner
# over the KV blocks listed in KV_IDX (mask applied), then over the blocks
# listed in FULL_KV_IDX (mask skipped), normalises the accumulator and stores
# the result to out_ptr0.  The per-row logsumexp (and optionally the row
# maximum) is stored as well, depending on OUTPUT_LOGSUMEXP / OUTPUT_MAX.
#
# All shapes, strides and tuning parameters are hard-coded for this
# specialization; ks0 and ks1 are the only runtime sizes.
@triton_heuristics.template(

num_stages=3,
num_warps=8,
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},

)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
    # Kernel configuration; the same values appear in `config_args` of the
    # decorator above and are repeated in the helper functions.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32
    # Short aliases for the pointer arguments.
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    MAX = arg_MAX
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
    # contiguous? If so, we don't need to do an indirect jump for every block

    # A sparse block must be a whole number of compute tiles in both directions.
    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    # (order: batch z, head h, sequence position m/n, head-dim element k).
    # K and V strides scale with the runtime size ks0.
    stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1

    # Problem sizes for this specialization; KV_LEN is the only runtime one.
    ZQ = 2
    HQ = 32
    Q_LEN = 2048
    ZKV = 2
    KV_LEN = ks0

    # Element type of Q; passed down to forward_inner as MATMUL_PRECISION.
    MATMUL_PRECISION = Q.dtype.element_ty

    # Grid coordinates: query tile, batch, query head.
    q_start = tl.program_id(0).to(INDEX_DTYPE)
    off_zq = tl.program_id(1).to(INDEX_DTYPE)
    off_hq = tl.program_id(2).to(INDEX_DTYPE)

    # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
    # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
    off_zkv = off_zq % ZKV
    # GQA: GQA_SHARED_HEADS consecutive query heads map to one KV head.
    off_hkv = off_hq // GQA_SHARED_HEADS
    # NOTE: off_g is not referenced again in this function.
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    # Rebase the pointers to this program's (batch, head) slice.
    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    # Setting up the TMA descriptors for Q, K, V
    # (USE_TMA is False in this specialization, so they are left as None;
    # desc_q is not referenced again in this function.)
    desc_q = None
    desc_k = None
    desc_v = None

    # Batch / head extents used to index the block-sparse tables.
    SPARSE_Z = 2
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    # Number of compute tiles per sparse block in each direction.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    # Strides of the block-sparse tables (num-blocks table and index table).
    stride_kv_num_blks_h = 16
    stride_kv_idx_h = 16*ks1
    stride_kv_idx_m = ks1

    # initialize pointer to m and l
    # m_i: running row maximum (starts at -inf), l_i: running row sum,
    # acc: unnormalised output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
    # NOTE: same expression as the offs_m assignment a few lines above.
    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    # Load the Q tile once; rows at or beyond Q_LEN are masked to 0 because
    # IS_DIVISIBLE is False.
    q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to it
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    # Trip count in BLOCK_N tiles, clamped to the number of tiles in KV_LEN.
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))


    # K and V pointers will be passed directly to forward_inner

    offs_n = kv_start + tl.arange(0, BLOCK_N)


    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
        q, K, V,
        desc_k, desc_v, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_start,
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        stride_kk, stride_kn, stride_vn, stride_vk,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
        # K and V pointers will be passed directly to forward_inner
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
            q, K, V,
            desc_k, desc_v, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_start,
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            stride_kk, stride_kn, stride_vn, stride_vk,
            IS_FULL_BLOCKS=True,
        )


    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)

    # Normalise the accumulator by the row sums.
    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1).to(INDEX_DTYPE)
    idx_hq = tl.program_id(2).to(INDEX_DTYPE)
    idx_m = offs_m[:, None].to(INDEX_DTYPE)
    idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)

    # Only rows below Q_LEN and columns below V_HEAD_DIM are written.
    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)

    tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
    # NOTE: xindex is not referenced again in this function; the store below
    # uses its own index expression with different m / h strides.
    xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
    tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)

    if OUTPUT_LOGSUMEXP:
        # One value per (batch, head, query row), laid out as [z * HQ + h, m].
        off_hz = off_zq * HQ + off_hq
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        # Uses log2, so the stored value is expressed in base 2.
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

    if OUTPUT_MAX:
        # Same layout as LSE; stores the running row maximum.
        off_hz = off_zq * HQ + off_hq
        max_ptrs = MAX + off_hz * Q_LEN + offs_m
        if IS_DIVISIBLE:
            tl.store(max_ptrs, m_i)
        else:
            tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
# Common Imports
# Process one [BLOCK_M, BLOCK_N] tile of the attention forward pass.
#
# Loads the K tile at kv_start + kv_offset, computes the scaled scores q @ k^T,
# applies the mask (unless IS_FULL_BLOCKS) and folds the tile into the running
# state: row maximum m_i, row sum l_i and output accumulator acc.  The V tile
# is loaded from the same KV offsets.  Returns the updated (acc, l_i, m_i).
@triton.jit
def forward_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,

):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    USE_TMA : tl.constexpr = False
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # -- load k --
    # NB: reversed order since K is transposed
    kv_base_offset = kv_start + kv_offset

    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    # Rows at or beyond KV_LEN are masked to 0 (IS_DIVISIBLE is False).
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
    # which is larger than the actual number of elements. To avoid access memory out of bound,
    # we need to mask out the elements that are out of Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    # The score modification in this specialization is the identity.
    tmp0 = (qk)
    post_mod_scores = tmp0


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Generated mask expression.  With L = in_ptr9[off_z], the final mask is
        #   (m >= n  and  n < L  and  m < L)
        #   or (n >= 2048  and  (n mod 2048) < L  and  ((n - m) mod 2048) == 0)
        # NOTE(review): L is presumably a per-batch valid length - confirm
        # against the caller that builds in_ptr9.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr9 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = tl.full([1], 2048, tl.int32)
        tmp15 = tmp3 >= tmp14
        # tmp16..tmp24: n mod 2048.  When the remainder is non-zero and its
        # sign differs from the divisor's, the divisor is added back.
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        # tmp28..tmp35: (n - m) mod 2048 with the same sign adjustment.
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        if CHECK_BLOCK_BOUNDARY:
            # Columns at or beyond KV_LEN are never attended to.
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    if not PRESCALE_QK:
        # Scale by RCP_LN2 so that exp2 can be used below.
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        # Rows whose maximum is still -inf are fully masked so far; use 0 as
        # their reference so the subtractions below do not compute -inf - -inf.
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    # alpha rescales the previous state to the new maximum; p holds this
    # tile's unnormalised probabilities.
    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # # -- scale and update acc --
    acc = acc * alpha[:, None]
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    # acc is passed as the third argument of tl.dot and the result replaces it.
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/l7/cl7uoyo4r2qj6bwjgyfp2qoiszzxw4xmxolbpem6a4fujmllivos.py b/SpecForge-ext/cache/compiled_kernels/l7/cl7uoyo4r2qj6bwjgyfp2qoiszzxw4xmxolbpem6a4fujmllivos.py new file mode 100644 index 0000000000000000000000000000000000000000..0abeb9709c1fc895af9f57f3098cc0d7fe494426 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/l7/cl7uoyo4r2qj6bwjgyfp2qoiszzxw4xmxolbpem6a4fujmllivos.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + 
(tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/m7/cm7u3olama3gox426hxhxixqvzhslez5o7pvi4bnehh2g4ww6k6i.py b/SpecForge-ext/cache/compiled_kernels/m7/cm7u3olama3gox426hxhxixqvzhslez5o7pvi4bnehh2g4ww6k6i.py new file mode 100644 index 0000000000000000000000000000000000000000..032480481750c4717a0dd986043db64b3d5b40c9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/m7/cm7u3olama3gox426hxhxixqvzhslez5o7pvi4bnehh2g4ww6k6i.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/mb/431d38c090b7287d184f2906df3dcec76d8ceff390e6163b8c2ad17f9731b0aa.best_config b/SpecForge-ext/cache/compiled_kernels/mb/431d38c090b7287d184f2906df3dcec76d8ceff390e6163b8c2ad17f9731b0aa.best_config new file mode 100644 index 0000000000000000000000000000000000000000..37707241555f35a01f7e4a693e0cda27ae37aab0 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/mb/431d38c090b7287d184f2906df3dcec76d8ceff390e6163b8c2ad17f9731b0aa.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 22, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/mb/cmb2yw2commhj2fugargtx6yqb3h6hzeysb2z7hhv474fxgiudkh.py b/SpecForge-ext/cache/compiled_kernels/mb/cmb2yw2commhj2fugargtx6yqb3h6hzeysb2z7hhv474fxgiudkh.py new file mode 100644 index 0000000000000000000000000000000000000000..d22197965182c7b56645ef7ba1a0c7a6e1b40798 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mb/cmb2yw2commhj2fugargtx6yqb3h6hzeysb2z7hhv474fxgiudkh.py @@ -0,0 +1,71 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = 
torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + s21 = arg0_1 + assert_size_stride(arg1_1, (1, 1, 40980, 128), (5245440, 5245440, 128, 1)) + assert_size_stride(arg2_1, (1, 1, 40980, 128), (5245440, 5245440, 128, 1)) + return (reinterpret_tensor(arg1_1, (1, 1, s21, 128), (5245440, 5245440, 128, 1), 0), reinterpret_tensor(arg2_1, (1, 1, s21, 128), (5245440, 5245440, 128, 1), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2048 + arg1_1 = rand_strided((1, 1, 40980, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 1, 40980, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/mb/cmb5zlldkrfleyhckhiomhpfsvkfkllotbopalddwkdfhlizveux.py b/SpecForge-ext/cache/compiled_kernels/mb/cmb5zlldkrfleyhckhiomhpfsvkfkllotbopalddwkdfhlizveux.py new file mode 100644 index 
0000000000000000000000000000000000000000..6dea38cc1cf5821d3fbcef3ed4e9721e1396da9e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mb/cmb5zlldkrfleyhckhiomhpfsvkfkllotbopalddwkdfhlizveux.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) 
+ ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mb/cmbhgz4c2hwbce6pchwlcnkxfrh55hxi5c4dp2sn4lys5xivdvad.py b/SpecForge-ext/cache/compiled_kernels/mb/cmbhgz4c2hwbce6pchwlcnkxfrh55hxi5c4dp2sn4lys5xivdvad.py new file mode 100644 index 0000000000000000000000000000000000000000..82edcbc670ab0c6deb7f1f1f1068037cb6f52d95 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mb/cmbhgz4c2hwbce6pchwlcnkxfrh55hxi5c4dp2sn4lys5xivdvad.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 
16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) 
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mb/cmbnzt5yo46dnfaaleroif4wljfg34wf7wuaecftgfewxdgwxqrv.py b/SpecForge-ext/cache/compiled_kernels/mb/cmbnzt5yo46dnfaaleroif4wljfg34wf7wuaecftgfewxdgwxqrv.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1a8e41eab36a7807ca0095c104d5310558c76c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mb/cmbnzt5yo46dnfaaleroif4wljfg34wf7wuaecftgfewxdgwxqrv.py @@ -0,0 +1,41 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 
16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mb/cmboruk2gyuhq43degftaqzb2abxergkmetbzmcgprn7eynqywpe.py b/SpecForge-ext/cache/compiled_kernels/mb/cmboruk2gyuhq43degftaqzb2abxergkmetbzmcgprn7eynqywpe.py new file mode 100644 index 0000000000000000000000000000000000000000..b722387f8bc4408ff1d7accfe2f3f2b95e1a0143 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/mb/cmboruk2gyuhq43degftaqzb2abxergkmetbzmcgprn7eynqywpe.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : 
tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mb/cmbsmrz2evoeh27rokw4xjaliz3z5dpfnexko6iss4qbyu6gl24y.py b/SpecForge-ext/cache/compiled_kernels/mb/cmbsmrz2evoeh27rokw4xjaliz3z5dpfnexko6iss4qbyu6gl24y.py new file mode 100644 index 
0000000000000000000000000000000000000000..090e07b556ddb0f3e37948755bf686afde4f2fb7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mb/cmbsmrz2evoeh27rokw4xjaliz3z5dpfnexko6iss4qbyu6gl24y.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jc/cjckdkumcnlkvhcgcfckom4kb3kkdpks5eouyhcpnwscklfm3o54.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, 
aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:0" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:0" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + 
filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index 
< r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream0 = get_raw_stream(0) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream0) + del arg0_1 
+ return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/md/cmdahcjbg6d3ontcfrzacjt2vey3sp4m2ipzxufeziixbxgozupy.py b/SpecForge-ext/cache/compiled_kernels/md/cmdahcjbg6d3ontcfrzacjt2vey3sp4m2ipzxufeziixbxgozupy.py new file mode 100644 index 0000000000000000000000000000000000000000..1121bad26e1aa384760746bade773e49192dc696 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/md/cmdahcjbg6d3ontcfrzacjt2vey3sp4m2ipzxufeziixbxgozupy.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = 
torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5p/c5pbkg5eq64emuv25ukki7a5dxvn2p2sh6jeiwb6b54tbidps5w7.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[8][1]cuda:6" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = 
call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 
'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = 
arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream6) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:6', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/md/cmdanr55qnd72zjkhza6zmry7bkzsyvzxqfecorr4qihht5do2aj.py b/SpecForge-ext/cache/compiled_kernels/md/cmdanr55qnd72zjkhza6zmry7bkzsyvzxqfecorr4qihht5do2aj.py new file mode 100644 index 0000000000000000000000000000000000000000..08b6c3346970b93854be93b9165e19981af865e8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/md/cmdanr55qnd72zjkhza6zmry7bkzsyvzxqfecorr4qihht5do2aj.py @@ -0,0 +1,354 @@ +# AOT ID: ['15_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath 
import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 
float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gp/cgpsfzswzdda3xqtuzsvn66yc4hwpr62mjit23ulsdznaiexdimr.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg4_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:6" = PlaceHolder[target=arg4_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg4_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 
'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, 
_tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:6" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:6" = PlaceHolder[target=argmax_1] +# %arg5_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg5_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties 
+triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = 
r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/64/c64aw6tdhc533bdu4lu2kexzu7e3rgjk5xeentmkjen77ksnc56t.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg7_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:6" = PlaceHolder[target=arg7_1] +# %sum_1 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg7_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, 
ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = 
r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s0 = arg3_1 + s14 = arg6_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg5_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg7_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream6 = get_raw_stream(6) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream6) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream6 = get_raw_stream(6) + triton_red_fused_argmax_1.run(arg4_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream6) 
+ del arg4_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream6 = get_raw_stream(6) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg5_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream6) + del arg5_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream6 = get_raw_stream(6) + triton_red_fused_clamp_min_div_sum_3.run(arg7_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream6) + del arg7_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1488 + arg1_1 = rand_strided((2, 1488, 32000), (47616000, 32000, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 47840000 + arg3_1 = 32000 + arg4_1 = rand_strided((2, 1488, 32000), (47840000, 32000, 1), device='cuda:6', dtype=torch.float32) + arg5_1 = rand_strided((2, 1488, 1), (1488, 1, 1), device='cuda:6', dtype=torch.int64) + arg6_1 = 1488 + arg7_1 = rand_strided((2, 1488, 1), (1488, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git 
a/SpecForge-ext/cache/compiled_kernels/mj/cmj7cw6kmumq5dhcjj5z77edfzandxlaftganyllgmz7p3fqeubv.py b/SpecForge-ext/cache/compiled_kernels/mj/cmj7cw6kmumq5dhcjj5z77edfzandxlaftganyllgmz7p3fqeubv.py new file mode 100644 index 0000000000000000000000000000000000000000..e34b01895ebffb128d7925e88152d2af7ef17960 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mj/cmj7cw6kmumq5dhcjj5z77edfzandxlaftganyllgmz7p3fqeubv.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': 
True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mj/cmjlfojpnpm5jni2ravb3komgycjc5mn3sbp2hi3ttso25z44mlc.py b/SpecForge-ext/cache/compiled_kernels/mj/cmjlfojpnpm5jni2ravb3komgycjc5mn3sbp2hi3ttso25z44mlc.py new file mode 100644 index 0000000000000000000000000000000000000000..d5152e1ebd27480c84d596094a1fef226dd4890d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mj/cmjlfojpnpm5jni2ravb3komgycjc5mn3sbp2hi3ttso25z44mlc.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from 
torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for 
r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mo/c6752addd23576471873dac02327581ecb2ce343109ada1c797255e7612dfec0.best_config b/SpecForge-ext/cache/compiled_kernels/mo/c6752addd23576471873dac02327581ecb2ce343109ada1c797255e7612dfec0.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cd9795263343a19ee8f06cf527807cd2d9adfee5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mo/c6752addd23576471873dac02327581ecb2ce343109ada1c797255e7612dfec0.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/mo/cmod6f73obbfcmomsopzfccdedg3pixsdgmsddb3365u65yr5grc.py b/SpecForge-ext/cache/compiled_kernels/mo/cmod6f73obbfcmomsopzfccdedg3pixsdgmsddb3365u65yr5grc.py new file mode 100644 index 0000000000000000000000000000000000000000..febec335d382050ab4be934fad2182d325e339e7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mo/cmod6f73obbfcmomsopzfccdedg3pixsdgmsddb3365u65yr5grc.py @@ -0,0 +1,63 @@ + +import triton 
+import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = 
tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) diff --git a/SpecForge-ext/cache/compiled_kernels/mr/cmrf526mm6su5xqnmne5okkpp4fxut73afx62yxlvlmbr6yjqxen.py b/SpecForge-ext/cache/compiled_kernels/mr/cmrf526mm6su5xqnmne5okkpp4fxut73afx62yxlvlmbr6yjqxen.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc447774794ac8f07ad58fcd4d3c561846ceaab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mr/cmrf526mm6su5xqnmne5okkpp4fxut73afx62yxlvlmbr6yjqxen.py @@ -0,0 +1,45 @@ + +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < 
xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mw/cmw2hjnuubs2eh7cuc54pem6cjhaz4jgplmqlhrsxfkzljxf7ndg.py b/SpecForge-ext/cache/compiled_kernels/mw/cmw2hjnuubs2eh7cuc54pem6cjhaz4jgplmqlhrsxfkzljxf7ndg.py new file mode 100644 index 0000000000000000000000000000000000000000..d0cce53bc49135beee8b8df7880259620b238d6d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mw/cmw2hjnuubs2eh7cuc54pem6cjhaz4jgplmqlhrsxfkzljxf7ndg.py @@ -0,0 +1,44 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 
'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + 
tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/mw/cmw5mlntlt7o73p24outkvtp73w3ylg6pk6fbqshalpowjpvoh47.py b/SpecForge-ext/cache/compiled_kernels/mw/cmw5mlntlt7o73p24outkvtp73w3ylg6pk6fbqshalpowjpvoh47.py new file mode 100644 index 0000000000000000000000000000000000000000..9f6d634393bf82f2d4f2eb3b55c98932af14f746 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mw/cmw5mlntlt7o73p24outkvtp73w3ylg6pk6fbqshalpowjpvoh47.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + 
tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/nj/cnjetfv56plc7qtm6fmop33lhpl3aynjulxdzpaemvs5mq3mbwy6.py b/SpecForge-ext/cache/compiled_kernels/nj/cnjetfv56plc7qtm6fmop33lhpl3aynjulxdzpaemvs5mq3mbwy6.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd3cbcc112a20f895816b1fbf87d809b8221e92 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nj/cnjetfv56plc7qtm6fmop33lhpl3aynjulxdzpaemvs5mq3mbwy6.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = 
torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/o2/co2y6jknd5elxfam4xxcqiwymrtcpe4mjfkclsyxgznl443h7uak.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:0" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index 
: Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, 
s24][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 
256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/p6/cp66nvdwdzgxajxp2yjtqapnwidpmfnzcyyalh6z5w6f6lf3aoej.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, 
s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 
+triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + 
new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream0 = get_raw_stream(0) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream0) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, 
aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream0 = get_raw_stream(0) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream0) + del primals_13 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:0', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/nj/cnjktwj7h4iwx4zghbum5atne46yt4ce4t5jnkkvyag35pn7glnh.py 
b/SpecForge-ext/cache/compiled_kernels/nj/cnjktwj7h4iwx4zghbum5atne46yt4ce4t5jnkkvyag35pn7glnh.py new file mode 100644 index 0000000000000000000000000000000000000000..197f8716a8896fe6cccdc97ed0a4b62b39709267 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nj/cnjktwj7h4iwx4zghbum5atne46yt4ce4t5jnkkvyag35pn7glnh.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 
16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/nj/e30035defbb2193e5ffa35928795cfe3d4d46a830167ee06e0c906dc55514d95.best_config b/SpecForge-ext/cache/compiled_kernels/nj/e30035defbb2193e5ffa35928795cfe3d4d46a830167ee06e0c906dc55514d95.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b5fe0bd195e2afa4eb939871edb76221f1e8606e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nj/e30035defbb2193e5ffa35928795cfe3d4d46a830167ee06e0c906dc55514d95.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": 
false, "time_taken_ms": 25, "triton_cache_hash": "E2MI47QNGZ2SJDA3U3EKHN7H3EYRAANF6T7N5SFT2CZJYNBAWCNQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/nr/182616ec61c3af932968e8d9773123f96c77b74ccf109f0c311b233fae7a1859.best_config b/SpecForge-ext/cache/compiled_kernels/nr/182616ec61c3af932968e8d9773123f96c77b74ccf109f0c311b233fae7a1859.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d3b58d3f4d2d57923b7dd1b74f7ebedcc091842a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nr/182616ec61c3af932968e8d9773123f96c77b74ccf109f0c311b233fae7a1859.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 15, "triton_cache_hash": "S3UH64TOYTN473KAATRMGKZ5SLQ46EZYJVPR6TIL7QNMYCB3MSMA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/nr/cnrsai53e7ialdlmhahe4zhhr3uynctmotdwm2ifjluqipqewdbf.py b/SpecForge-ext/cache/compiled_kernels/nr/cnrsai53e7ialdlmhahe4zhhr3uynctmotdwm2ifjluqipqewdbf.py new file mode 100644 index 0000000000000000000000000000000000000000..607d77190a897a74dc7f84172ba360f224e97797 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nr/cnrsai53e7ialdlmhahe4zhhr3uynctmotdwm2ifjluqipqewdbf.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 
'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) 
+ tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/nw/cnwbaridllvb5ehfoe22t5rapeg7phedfner2dyilwt2r5xa43b4.py b/SpecForge-ext/cache/compiled_kernels/nw/cnwbaridllvb5ehfoe22t5rapeg7phedfner2dyilwt2r5xa43b4.py new file mode 100644 index 0000000000000000000000000000000000000000..b49863add90b7b53885274f236db5280c9b7f2a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nw/cnwbaridllvb5ehfoe22t5rapeg7phedfner2dyilwt2r5xa43b4.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = 
torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6j/c6jzxztdxbjv5b23nfmgzgtizqp77h7aeak5j2jukmz3roqeiw3k.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/oq/coq4vpmjynhsbr3eh73i7stji3nesrs4kq7nxoqo55gux322a7qi.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# 
batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:6"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 
1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 
s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), 
Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, 
((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 
127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 
'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 
+ tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), 
tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ms/cmss4vkiru5swhxy33hnse2fxrcngpiw7ozbdu4mwbcfqlcznlyx.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:6" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 
'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 
= tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/li/climzgziptsjlz4tsuxacvt6muyoeemo7l2bxcyotfjy3vij6hbt.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:6" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:6" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 
(((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), 
kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6}) +# %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %index_put 
: Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/n4/cn44zhfsqswlfcaweodtohqt7stdnvkic4sqxo4katmawv3urke5.py +# Topologically Sorted Source Nodes: [batched_outputs_3], 
Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:6" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 
16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5u/c5ucf2it6fwhmru32vpy2xmwkk4vg7ppnoxmhnkf6ifpb2h2tjsl.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), 
((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + 
+@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + 
ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/35/c35huqp6ngzh67kt32kuxoqpghc32fstv4zogcouzabdxxwta3sl.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, 
math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * 
((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream6) + buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream6) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 
128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream6) + del arg2_1 + buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, 
triton_red_fused__to_copy_sum_2_r0_numel, stream=stream6) + buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream6) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + 
triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream6) + del buf4 + buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream6) + del buf12 + buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream6) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 
'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream6) + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream6) + del buf21 + buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) 
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream6) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream6) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, 
descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream6) + del buf26 + return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1488 + arg1_1 = 1488 + arg2_1 = rand_strided((2, ), (1, ), device='cuda:6', dtype=torch.int64) + arg3_1 = 1488 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/nw/cnwg5f7d7jo4twbzynnp7xfazo4khvh62gkneggr5aray7kujit5.py b/SpecForge-ext/cache/compiled_kernels/nw/cnwg5f7d7jo4twbzynnp7xfazo4khvh62gkneggr5aray7kujit5.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3dd8f508951ef59d7bcc7477fa3970c2229ccb --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/nw/cnwg5f7d7jo4twbzynnp7xfazo4khvh62gkneggr5aray7kujit5.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : 
tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. 
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * 
off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/pl/cplxdj2pdqlunap7iibjfdmmxdbawko34cbmdsgn2wmglrgs4wjc.py b/SpecForge-ext/cache/compiled_kernels/pl/cplxdj2pdqlunap7iibjfdmmxdbawko34cbmdsgn2wmglrgs4wjc.py new file mode 100644 index 0000000000000000000000000000000000000000..6f801680585815fb4e40e241b55901d1ab62ac6c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pl/cplxdj2pdqlunap7iibjfdmmxdbawko34cbmdsgn2wmglrgs4wjc.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': 
'*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 
'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = 
arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v 
= tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
# Accumulates dK/dV contributions for one tile of keys/values by walking over
# the Q tiles selected through `q_indices`, delegating the per-tile math to
# bwd_dkdv_block_mn and advancing the Q/DO pointers between tiles.
#
# Inputs used directly here:
#   Q, DO              - base pointers used to build the per-tile pointer grids.
#   dk, dv             - running accumulators; threaded through every tile call
#                        and returned at the end.
#   offs_n1, offs_m1   - index vectors for the KV tile and the current Q tile;
#                        offs_m1 is advanced in place after every iteration.
#   stride_*           - element strides for Q and DO along (m, d).
#   q_indices, sparse_q_num_blocks - block-sparse index list and its length.
#   ks0, ks1           - read as Q_LEN and KV_LEN respectively.
# Everything else (arg_*, in_ptr16, out_ptr0, ks2..ks8, DELTA, LSE, k, v,
# off_z, off_hq, MATMUL_PRECISION, IS_FULL_BLOCKS) is only forwarded to
# bwd_dkdv_block_mn.
#
# Returns the updated (dk, dv) pair.
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Compile-time configuration baked in by the code generator. The same
    # table is repeated in each helper; several entries (e.g. PRESCALE_QK,
    # SM_SCALE, GQA_SHARED_HEADS) are not referenced in this function body.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Number of BLOCK_M1-sized tiles that make up one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    # Approximately 1 / ln(2); forwarded to bwd_dkdv_block_mn.
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed: head-dim offsets run along rows and the Q
    # positions along columns. DO is addressed in its natural (m, d) layout.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    # The upper bound is clamped to at least one tile via tl.maximum(..., 1).
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        # Process one Q tile and fold its contribution into dk / dv.
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        # `offset` is a distance in Q positions: either the next tile inside
        # the current sparse block or a jump to the next listed sparse block.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # Keep the pointer grids and the logical index vector in lock-step.
        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
# Utility triton funcs

# Returns how far (in elements along the iterated dimension) the caller must
# advance to reach the next tile.
#   loop_iter             - index of the tile that was just processed.
#   col_indices           - pointer to the list of sparse block indices.
#   total_blocks          - number of valid entries in col_indices.
#   SPARSE_BLOCK          - size of one sparse block.
#   SPARSE_BLOCK_MULTIPLE - tiles per sparse block.
#   BLOCK                 - size of one tile.
#   BLOCKS_ARE_CONTIGUOUS - compile-time shortcut: always step by BLOCK.
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    # The look-ahead load is masked off past the end of the index list.
    # NOTE(review): no `other=` is given, so next_block should not be relied
    # on when cur_block_idx + 1 >= total_blocks - confirm callers stop first.
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    # True only on the last tile of the current sparse block.
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance from the last tile of this sparse block to the first tile of
    # the next listed sparse block.
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Arithmetic select on needs_jump: jump_to_block when set, BLOCK otherwise.
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

# Wraps `indices` modulo `max_len`; with max_len=None the indices are
# returned unchanged.
@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

# Loads through a block pointer, choosing which dimensions get a boundary
# check (with zero padding) from the two compile-time flags:
#   not IS_DIVISIBLE  -> check dimension 0
#   not SAFE_HEAD_DIM -> check dimension 1
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

# Loads a 2D tile through raw pointers, masking whichever axes are not flagged
# as divisible; masked-off elements read as 0.0.
#   ptr                 - base pointer, or an already-expanded 2D pointer grid
#                         when either stride is None.
#   offs_m, offs_n      - index vectors for the row and column axes.
#   stride_m, stride_n  - strides used to expand `ptr`; the expansion only
#                         happens when both are not None.
#   IS_DIVISIBLE_M / _N - when True, the corresponding axis is loaded unmasked.
#   M_LEN, N_LEN        - exclusive upper bounds compared against offs_m / offs_n.
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
# Generated reduction kernel with two passes over the reduction axis for each
# x row (xnumel is fixed to 128):
#   pass 1 - sums the in_ptr0 row (row stride ks0) into an int64 accumulator
#            and stores the int32 total to out_ptr1.
#   pass 2 - copies the in_ptr1 row, narrowed to int32, to out_ptr2, and
#            writes the value 1 into out_ptr3 at a position derived from each
#            element (see the inline notes).
@triton.jit
def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Pass 1: per-lane partial sums, collapsed along axis 1 after the loop.
    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.int64)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        tmp4 = _tmp3 + tmp2
        # Only in-bounds lanes update the accumulator.
        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    tmp5 = tmp3.to(tl.int32)
    tl.store(out_ptr1 + (x0), tmp5, xmask)
    # Pass 2: reuses the row total (tmp5) as a per-row cut-off.
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # The row stride expression evaluates to ks0 when ks0 > 1 and to 1 otherwise.
        tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp7 = tmp6.to(tl.int32)
        tmp8 = r0_1
        # Positions at or beyond the row total fall back to ks0 instead of the
        # loaded value.
        tmp9 = tmp8 < tmp5
        tmp10 = ks0
        tmp11 = tl.where(tmp9, tmp7, tmp10)
        # Negative indices are wrapped by adding (1 + ks0).
        tmp12 = 1 + ks0
        tmp13 = tmp11 + tmp12
        tmp14 = tmp11 < 0
        tmp15 = tl.where(tmp14, tmp13, tmp11)
        # Bounds check on the scatter index; masked-off lanes are exempt.
        tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))")
        tmp17 = tl.full([1, 1], 1, tl.int32)
        tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask)
        # Scatter target: out_ptr3 is addressed with a row stride of (ks0 + 1).
        tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask)
# Generated reduction kernel. For each x row it:
#   1. accumulates the sum of squares of the in_ptr0 row (row stride ks0) in float32,
#   2. computes rsqrt(sum / ks0 + in_ptr1) and stores it to in_out_ptr0[x],
#   3. writes in_ptr2[r] * (in_ptr0[x, r] * rsqrt_value) to out_ptr0[x, r].
# in_ptr1 is used as a scalar value rather than dereferenced.
# NOTE(review): this has the shape of an RMSNorm forward with in_ptr1 as
# epsilon and in_ptr2 as the weight vector - confirm against the source graph.
@triton.jit
def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Pass 1: per-lane partial sums of squares, collapsed along axis 1 below.
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # 'evict_last' here; the same row is read again in the second pass.
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tmp1 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        # Only in-bounds lanes update the accumulator.
        _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp9 = in_ptr1
    # The divisor is ks0 (also the row stride), not r0_numel.
    tmp6 = ks0
    tmp7 = tmp6.to(tl.float32)
    tmp8 = (tmp4 / tmp7)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp8 + tmp10
    tmp12 = libdevice.rsqrt(tmp11)
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp12, xmask)
    # Pass 2: scale each element by the row's rsqrt value and by in_ptr2[r].
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # in_ptr2 is indexed by r only, so it is shared across all x rows.
        tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tmp15 * tmp12
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp13 * tmp17
        tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask)
+import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ll/clljyrbqq46ev7waq2poicg2u2bcfndqh2qfifloogsqahenydfv.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:6" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax 
+triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = 
tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/e5/ce5roatsezodqf5qpd32hkspt4fcfiqo7jobom7pzhsftuk27gji.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[8, 2048, 32000][65760000, 32000, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, 
DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = 
xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zb/czbkusdbmiy6kikq4qod7mdx6wj5hskghcdn5gr4z6hcdy4v3nbz.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg2_1] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = 
async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + 
rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uo/cuoflv2dnlrrnqtf32tqz435pzpw3hvrmamzms4siwsvxeljqc7k.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg3_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %sum_2 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + 
+@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, 
eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zf/czfjpc2bmajr63w5p55xld3zlnib2cxdzyx3n2uan6r324ujkv52.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = 
{}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + 
R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (8, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream6 = get_raw_stream(6) + triton_red_fused_argmax_0.run(arg0_1, buf0, 16384, 32000, stream=stream6) + del arg0_1 + buf1 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], 
Original ATen: [aten.argmax] + stream6 = get_raw_stream(6) + triton_red_fused_argmax_1.run(arg1_1, buf1, 16384, 32000, stream=stream6) + del arg1_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + stream6 = get_raw_stream(6) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, buf3, 2, 8192, stream=stream6) + del arg2_1 + del buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + stream6 = get_raw_stream(6) + triton_red_fused_sum_3.run(arg3_1, buf5, 2, 8192, stream=stream6) + del arg3_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream6 = get_raw_stream(6) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream6) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:6', dtype=torch.bfloat16) + arg1_1 = rand_strided((8, 2048, 32000), (65760000, 32000, 1), device='cuda:6', dtype=torch.float32) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64) + arg3_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + 
compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/py/cpyon4zgaupgqfwtaeshxummq5taahi4k54ubix2xgrrupxyugiq.py b/SpecForge-ext/cache/compiled_kernels/py/cpyon4zgaupgqfwtaeshxummq5taahi4k54ubix2xgrrupxyugiq.py new file mode 100644 index 0000000000000000000000000000000000000000..a97a9257ac9c219eb4fb3cb3ec5c3f2f296b6d5d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/py/cpyon4zgaupgqfwtaeshxummq5taahi4k54ubix2xgrrupxyugiq.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 
'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + 
tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/rc/crc3qjkg45bppwc6zjt3chvkhoetzafaxeolofypru5p7mguzfwr.py b/SpecForge-ext/cache/compiled_kernels/rc/crc3qjkg45bppwc6zjt3chvkhoetzafaxeolofypru5p7mguzfwr.py new file mode 100644 index 0000000000000000000000000000000000000000..04a80c95c666cb6abd9e5c0efa5beca0cc9c9420 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rc/crc3qjkg45bppwc6zjt3chvkhoetzafaxeolofypru5p7mguzfwr.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) 
+@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/rc/d7bf3f60aaae1a41a45bf445bbc6ca3edd6f56772ff6e8c6f864f45948d3b387.best_config b/SpecForge-ext/cache/compiled_kernels/rc/d7bf3f60aaae1a41a45bf445bbc6ca3edd6f56772ff6e8c6f864f45948d3b387.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b3edbf99a0eba8d80382a60b6e8c1e8217d859f3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rc/d7bf3f60aaae1a41a45bf445bbc6ca3edd6f56772ff6e8c6f864f45948d3b387.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "L4G4T7LJLJ5UKCYTJO6FX7X7X5CAA5AHUH7H57L6AM4BNGXXAAVQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/rh/crhp3m42ilrikc67vfludc3lvyvflmwovdjakmxhro7746x2vlmz.py b/SpecForge-ext/cache/compiled_kernels/rh/crhp3m42ilrikc67vfludc3lvyvflmwovdjakmxhro7746x2vlmz.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7327bc33480f24cf3aedbd8555256d4b2fc7e7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rh/crhp3m42ilrikc67vfludc3lvyvflmwovdjakmxhro7746x2vlmz.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import 
_align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vn/cvn736oi3yulmpiv2kyhznjdsmsi3u35zxuqvuyabq7sna42w72l.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 
1]cuda:7" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. 
perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream7) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/rp/crpp3nsuyiu37keoxcc4eefiyo35djuqgd2blif3lhebmqyhybl4.py b/SpecForge-ext/cache/compiled_kernels/rp/crpp3nsuyiu37keoxcc4eefiyo35djuqgd2blif3lhebmqyhybl4.py new file mode 100644 index 0000000000000000000000000000000000000000..da609716dec7092501c4098ce291b273f7e0280b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rp/crpp3nsuyiu37keoxcc4eefiyo35djuqgd2blif3lhebmqyhybl4.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import 
libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + 
LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + 
stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/sh/cshnc2xfeqljlx2yygmlwikvuwrqaw3mjbtw2iod3tdoksj22xly.py b/SpecForge-ext/cache/compiled_kernels/sh/cshnc2xfeqljlx2yygmlwikvuwrqaw3mjbtw2iod3tdoksj22xly.py new file mode 100644 index 0000000000000000000000000000000000000000..11e91c65c8a6389c0cb71e4a9148ea49a628bd93 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/sh/cshnc2xfeqljlx2yygmlwikvuwrqaw3mjbtw2iod3tdoksj22xly.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = 
xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), 
tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/sv/b2093eaa973d2fd3035ffcb4984647bab53e88371fbc311a150a08927762c4bd.best_config b/SpecForge-ext/cache/compiled_kernels/sv/b2093eaa973d2fd3035ffcb4984647bab53e88371fbc311a150a08927762c4bd.best_config new file mode 100644 index 0000000000000000000000000000000000000000..8920a6ebe9dac1a267cf3c5b5085d70019ad08a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/sv/b2093eaa973d2fd3035ffcb4984647bab53e88371fbc311a150a08927762c4bd.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 36, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/u5/8b83e6eda2f1357a0520c73629786ff9cf53d7b73ea7d2f6a2cce5eaf56cc486.best_config b/SpecForge-ext/cache/compiled_kernels/u5/8b83e6eda2f1357a0520c73629786ff9cf53d7b73ea7d2f6a2cce5eaf56cc486.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/u5/8b83e6eda2f1357a0520c73629786ff9cf53d7b73ea7d2f6a2cce5eaf56cc486.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/u5/cu55o4f4khg2wuonhaoogm7cwe7beivg5otgutgnrv3xkelakvcz.py b/SpecForge-ext/cache/compiled_kernels/u5/cu55o4f4khg2wuonhaoogm7cwe7beivg5otgutgnrv3xkelakvcz.py new file mode 100644 index 0000000000000000000000000000000000000000..b809cd12bcd19eb758eddebf567fa71de8221cff --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/u5/cu55o4f4khg2wuonhaoogm7cwe7beivg5otgutgnrv3xkelakvcz.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = 
R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/u5/cu5qk2aivacbrthkcnzf3am2li7s422lho6nzfcqmwax6ww2maby.py b/SpecForge-ext/cache/compiled_kernels/u5/cu5qk2aivacbrthkcnzf3am2li7s422lho6nzfcqmwax6ww2maby.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9e63849afd79375fcbd00530d7d74ab2ff473f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/u5/cu5qk2aivacbrthkcnzf3am2li7s422lho6nzfcqmwax6ww2maby.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': 
'*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + 
tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/u5/cu5vbm4k5rx2ckzgfjj47hdlzuvwn5xanjqx3duors7zpk22vecs.py b/SpecForge-ext/cache/compiled_kernels/u5/cu5vbm4k5rx2ckzgfjj47hdlzuvwn5xanjqx3duors7zpk22vecs.py new file mode 100644 index 0000000000000000000000000000000000000000..ad945deaa5b11d329b019c4310a7a085c49b854b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/u5/cu5vbm4k5rx2ckzgfjj47hdlzuvwn5xanjqx3duors7zpk22vecs.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 
'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ub/cubczabrf2ptryq2athnmru2byipbgarioa3462puxv2jwv6vm4c.py b/SpecForge-ext/cache/compiled_kernels/ub/cubczabrf2ptryq2athnmru2byipbgarioa3462puxv2jwv6vm4c.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3b5a4196ea37b5fb3cc004893d76f9b4fe697c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ub/cubczabrf2ptryq2athnmru2byipbgarioa3462puxv2jwv6vm4c.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 
'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, 
other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + 
(tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ub/fe085c5b3cbeab256cc820d5b6ed03a818924320b34066557efacfa4d9dd4cf3.best_config b/SpecForge-ext/cache/compiled_kernels/ub/fe085c5b3cbeab256cc820d5b6ed03a818924320b34066557efacfa4d9dd4cf3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..266f460278e85c345185451f023083ef4f3937ee --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ub/fe085c5b3cbeab256cc820d5b6ed03a818924320b34066557efacfa4d9dd4cf3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ue/cueey4eyo3vezj23w73bspqk3hwft3v6gzwq43jxzs4bcvkqnk3y.py b/SpecForge-ext/cache/compiled_kernels/ue/cueey4eyo3vezj23w73bspqk3hwft3v6gzwq43jxzs4bcvkqnk3y.py new file mode 100644 index 0000000000000000000000000000000000000000..93621af1a8e9a2893ed0491ac435482ea7e91ec8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ue/cueey4eyo3vezj23w73bspqk3hwft3v6gzwq43jxzs4bcvkqnk3y.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': 
'*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) 
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ue/cueqmeuvo5ppmfcn2jymnmdqmbzmj7asfrfy3qh5zrpmfg7bghek.py 
b/SpecForge-ext/cache/compiled_kernels/ue/cueqmeuvo5ppmfcn2jymnmdqmbzmj7asfrfy3qh5zrpmfg7bghek.py new file mode 100644 index 0000000000000000000000000000000000000000..971e9dd2cf0a3d230a268ff8db33e1b847480a43 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ue/cueqmeuvo5ppmfcn2jymnmdqmbzmj7asfrfy3qh5zrpmfg7bghek.py @@ -0,0 +1,352 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ij/cijk7uwjezxqwmpqndb2m6jmplqondyz4xw5ywxgoarpxly5nkmr.py +# 
Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': 
False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gi/cgim5omgpzmqctohwn6fzcyz4k522sj2zt2nk2nh2j3m4l73x44y.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' 
+import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = 
r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uw/cuw54qt3adv2mqfssnwpmkprl2vhbjhxprmh3bb5537pl6evv4j7.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:1" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:1" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:1" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) 
+# %sum_1 : Tensor "i64[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, 
in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3m/c3mfnz3jqpdzlott45yvd2kki53nhik366siiuob2jitdkwx6tyg.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg6_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:1" = PlaceHolder[target=arg6_1] +# %sum_1 : Tensor "i64[][]cuda:1" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:1" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 
1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, 
xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: 
[aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream1 = get_raw_stream(1) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream1) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream1 = get_raw_stream(1) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream1) + del arg3_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream1 = get_raw_stream(1) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream1) + del arg4_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream1 = get_raw_stream(1) + triton_red_fused_clamp_min_div_sum_3.run(arg6_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream1) + del arg6_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1041 + arg1_1 = rand_strided((2, 1041, 32000), (33312000, 32000, 1), device='cuda:1', dtype=torch.bfloat16) + arg2_1 = 33536000 + arg3_1 = rand_strided((2, 1041, 32000), (33536000, 32000, 1), device='cuda:1', dtype=torch.float32) + arg4_1 = rand_strided((2, 1041, 1), (1041, 1, 1), device='cuda:1', dtype=torch.int64) + 
arg5_1 = 1041 + arg6_1 = rand_strided((2, 1041, 1), (1041, 1, 1), device='cuda:1', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ue/d87e7669b34a5476bfdd014272dcdf665ec816f0fbecb65ec0efdcb243763621.best_config b/SpecForge-ext/cache/compiled_kernels/ue/d87e7669b34a5476bfdd014272dcdf665ec816f0fbecb65ec0efdcb243763621.best_config new file mode 100644 index 0000000000000000000000000000000000000000..758b6d2bc873793fbcdbae9195bd9dbe5fd4de6d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ue/d87e7669b34a5476bfdd014272dcdf665ec816f0fbecb65ec0efdcb243763621.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 66, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/uj/cujounzyjqbo2f4xgxaf65o4jlm636lpvcp3rve5m4sfbwcoltf7.py b/SpecForge-ext/cache/compiled_kernels/uj/cujounzyjqbo2f4xgxaf65o4jlm636lpvcp3rve5m4sfbwcoltf7.py new file mode 100644 index 0000000000000000000000000000000000000000..6c577b58cf2395725b6ff198082158b150eb6669 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/uj/cujounzyjqbo2f4xgxaf65o4jlm636lpvcp3rve5m4sfbwcoltf7.py @@ -0,0 +1,44 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + 
+@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = 
r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) diff --git a/SpecForge-ext/cache/compiled_kernels/um/cumyi5vygpm43r5iray76spbvrjf35zwquvszjyymkb263yffw5p.py b/SpecForge-ext/cache/compiled_kernels/um/cumyi5vygpm43r5iray76spbvrjf35zwquvszjyymkb263yffw5p.py new file mode 100644 index 0000000000000000000000000000000000000000..08ee9286f9a9b49af2b3bc41682a0e90ae503347 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/um/cumyi5vygpm43r5iray76spbvrjf35zwquvszjyymkb263yffw5p.py @@ -0,0 +1,352 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = 
torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/x7/cx7fsejzde6zv22nl7w3xpjhybajijgeetsfqi733ibymkptkdrq.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 
'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + 
tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/af/cafe3dsuelcloemwu5jdikp7lqano5qxv7iayhtm5xgji2xvr4k6.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:3" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2o/c2otr5mtbf3tmh4ztmfjn6qv6r3raha22m4sr5h4kaplsk53xtg4.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => 
sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:3" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:3" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:3" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 
'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cj/ccjx73kwqy3z57a3fjxor5ma5tgytixf7htmrtqxzyfleohcklv4.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg6_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:3" = PlaceHolder[target=arg6_1] +# %sum_1 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: 
+ def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream3 = get_raw_stream(3) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream3) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream3 = get_raw_stream(3) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream3) + del arg3_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream3 = get_raw_stream(3) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream3) + del arg4_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 
2*s14 + stream3 = get_raw_stream(3) + triton_red_fused_clamp_min_div_sum_3.run(arg6_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream3) + del arg6_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2014 + arg1_1 = rand_strided((2, 2014, 32000), (64448000, 32000, 1), device='cuda:3', dtype=torch.bfloat16) + arg2_1 = 64672000 + arg3_1 = rand_strided((2, 2014, 32000), (64672000, 32000, 1), device='cuda:3', dtype=torch.float32) + arg4_1 = rand_strided((2, 2014, 1), (2014, 1, 1), device='cuda:3', dtype=torch.int64) + arg5_1 = 2014 + arg6_1 = rand_strided((2, 2014, 1), (2014, 1, 1), device='cuda:3', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ur/curt662mnb3uk6niziki36t6ptmy6aozj55ereltl4nydoraw2av.py b/SpecForge-ext/cache/compiled_kernels/ur/curt662mnb3uk6niziki36t6ptmy6aozj55ereltl4nydoraw2av.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbe8064cebbddecbed365f5c203288d693c967e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ur/curt662mnb3uk6niziki36t6ptmy6aozj55ereltl4nydoraw2av.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + 
+num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 
'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # 
Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. 
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE 
= (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/uu/cuug6lebxsr4vzmtgo3d4szmyci3uml2hmuljap6ewwayzcgizwe.py b/SpecForge-ext/cache/compiled_kernels/uu/cuug6lebxsr4vzmtgo3d4szmyci3uml2hmuljap6ewwayzcgizwe.py new file mode 100644 index 0000000000000000000000000000000000000000..0152483d43d7fb6a2ef16deca7b75f1f2cd033a7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/uu/cuug6lebxsr4vzmtgo3d4szmyci3uml2hmuljap6ewwayzcgizwe.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 
'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 
'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, 
K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. 
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + 
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulate `dq` for one query tile by visiting KV tiles one after another.
    # Each iteration delegates the math to `bwd_dq_block_mn` and then advances
    # the K/V tile pointers and `offs_n2` by the step that
    # `get_offset_for_next_block` derives from `kv_indices`.
    # Returns the updated `dq` accumulator.

    # Compile-time configuration baked into this kernel instance.
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # Number of BLOCK_N2-sized tiles that make up one sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = ks0

    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Transposed tile pointers: head-dim offsets index rows, KV offsets index columns.
    kv_cols = offs_n2[None, :]
    kT_tile_ptrs = K + kv_cols * stride_kn + offs_k[:, None] * stride_kd
    vT_tile_ptrs = V + kv_cols * stride_vn + offs_v[:, None] * stride_vd

    # Never iterate past the number of BLOCK_N2 tiles that fit in KV_LEN (at least one).
    max_tiles = tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)
    num_tiles = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, max_tiles)

    for tile in range(0, num_tiles):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dq, q, kT_tile_ptrs, vT_tile_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in KV positions) from this tile to the next one to visit.
        step = get_offset_for_next_block(
            tile, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n2 += step
        kT_tile_ptrs += step * stride_kn
        vT_tile_ptrs += step * stride_vn

    return dq
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # One (query tile x KV tile) step of the dQ computation.
    #
    # Recomputes the scaled scores q @ kT, masks them (bounds always, the
    # generated mask only when IS_FULL_BLOCKS is false), turns them into
    # p = exp2(scores * RCP_LN2 - lse), forms ds = p * (do @ vT - Di) and adds
    # ds @ trans(kT) to `dq`. Returns the updated `dq`.
    #
    # Compile-time configuration baked into this kernel instance.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: offsets/flags are passed in reversed (head-dim first) order since K is transposed.
    # Strides are None because kT_ptrs already is the full 2-D pointer tile.
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    # With IS_DIVISIBLE false the indices are wrapped modulo the sequence
    # length so the mask arithmetic below never sees an out-of-range position.
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but since we iterate across the N dim here
    # the M indices can still read out of bounds for the PIDs spanning the Q_LEN boundary.
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # No score modification was generated for this kernel: scores pass through unchanged.
    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        # KV positions past KV_LEN must not contribute.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Generated mask. With L = in_ptr16[off_z] an entry (m, n) is kept when
        #   (m >= n and n < L and m < L)
        #   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0),
        # where "mod" is made non-negative by the sign-correction steps below.
        # NOTE(review): L looks like a per-batch valid length -- confirm against the mask_mod.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = tl.full([1], 2048, tl.int32)
        tmp15 = tmp3 >= tmp14
        # n mod 2048, adjusted so the result takes the sign of the divisor.
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        # (n - m) mod 2048 with the same sign adjustment.
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        # Apply the mask for a partially masked block.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Change of base so that exp2 can be used below.
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB: offsets/flags are passed in reversed order since V is transposed.
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # No joint modification was generated: the score gradient passes through unchanged.
    tmp39 = (ds)
    grad_scores = tmp39


    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # Computed but not read anywhere in this function (no extra buffer writes were generated).
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulate `dk` and `dv` for one KV tile by visiting query tiles one
    # after another. Each iteration delegates the math to `bwd_dkdv_block_mn`
    # and then advances the Q / dO tile pointers and `offs_m1` by the step that
    # `get_offset_for_next_block` derives from `q_indices`.
    # Returns the updated `(dk, dv)` accumulators.

    # Compile-time configuration baked into this kernel instance.
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # Number of BLOCK_M1-sized tiles that make up one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = ks0

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head dim along rows); dO is addressed row-major.
    qT_tile_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_tile_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    max_tiles = tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)
    num_tiles = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, max_tiles)

    for tile in range(0, num_tiles):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dk, dv, qT_tile_ptrs, k, v, do_tile_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in query positions) from this tile to the next one to visit.
        step = get_offset_for_next_block(
            tile, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        offs_m1 += step
        qT_tile_ptrs += step * stride_qm
        do_tile_ptrs += step * stride_dom

    return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # One (KV tile x query tile) step of the dK / dV computation.
    #
    # Works on the transposed score tile k @ qT: masks it (bounds always, the
    # generated mask only when IS_FULL_BLOCKS is false), turns it into
    # pT = exp2(scores * RCP_LN2 - lse), adds pT @ do to `dv`, forms
    # dsT = pT * (v @ trans(do) - Di) and adds dsT @ trans(qT) to `dk`.
    # Returns the updated `(dk, dv)`.
    #
    # Compile-time configuration baked into this kernel instance.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: offsets/flags are passed in reversed (head-dim first) order since Q is transposed.
    # Strides are None because qT_ptrs already is the full 2-D pointer tile.
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf entries with 0 so the subtraction below cannot produce NaN.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # With IS_DIVISIBLE false the indices are wrapped modulo the sequence
    # length so the mask arithmetic below never sees an out-of-range position.
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but since we iterate across the M dim here
    # the N indices can still read out of bounds for the PIDs spanning the KV_LEN boundary.
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    pre_mod_scores = qkT
    # No score modification was generated for this kernel: scores pass through unchanged.
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        # Query positions past Q_LEN must not contribute.
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Generated mask. With L = in_ptr16[off_z] an entry (m, n) is kept when
        #   (m >= n and n < L and m < L)
        #   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0),
        # where "mod" is made non-negative by the sign-correction steps below.
        # NOTE(review): L looks like a per-batch valid length -- confirm against the mask_mod.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = tl.full([1], 2048, tl.int32)
        tmp55 = tmp43 >= tmp54
        # n mod 2048, adjusted so the result takes the sign of the divisor.
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        # (n - m) mod 2048 with the same sign adjustment.
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # Apply the mask for a partially masked block (this branch only runs
        # when IS_FULL_BLOCKS is false).
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Change of base so that exp2 can be used below.
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # No joint modification was generated: the score gradient passes through unchanged.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if not WRITE_DQ:
        # Not taken for this instance: WRITE_DQ is True above.
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute dK.
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Return how far to advance after finishing tile `loop_iter`.
    #
    # Inside a sparse block the step is simply BLOCK. On the last tile of a
    # sparse block the step instead jumps to the start of the next sparse block
    # listed in `col_indices` (the load of that entry is masked once the list
    # of `total_blocks` entries is exhausted).
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    sparse_pos = loop_iter // SPARSE_BLOCK_MULTIPLE
    following_pos = sparse_pos + 1
    this_block = tl.load(col_indices + sparse_pos, eviction_policy="evict_last")
    following_block = tl.load(col_indices + following_pos, eviction_policy="evict_last", mask=following_pos < total_blocks)
    # True only on the last tile of the current sparse block.
    at_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance to the next sparse block, minus the tiles already walked in this one.
    block_jump = (following_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Arithmetic select between the jump and the plain per-tile step.
    return block_jump * at_block_end + (1 - at_block_end) * BLOCK

@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wrap `indices` modulo `max_len`; with no `max_len` they are returned untouched.
    bounded = indices % max_len if max_len is not None else indices
    return bounded
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Load through a block pointer, boundary-checking (with zero padding) only
    # the dimensions that need it: dim 0 when IS_DIVISIBLE is false, dim 1 when
    # SAFE_HEAD_DIM is false.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Masked 2-D load. Masked-out elements read as 0.0.
    #
    # When both strides are given, `ptr` is a base pointer and the 2-D pointer
    # tile is built here; otherwise `ptr` is used as passed in. A dimension is
    # masked against its length only when its IS_DIVISIBLE_* flag is false.
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Neither dimension needs a mask.
            return tl.load(ptr)
        else:
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 32768},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # Row-wise argmax: for every row x < xnumel, scan the 32000 contiguous
    # values starting at in_ptr0 + 32000 * x and store the index of the
    # maximum to out_ptr0[x].
    r0_numel = 32000  # the row length is baked in; the runtime argument is overridden
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = rows < xnumel
    lanes = tl.arange(0, R0_BLOCK)[None, :]
    # Running per-lane best value / index, seeded with -inf and INT32_MAX.
    best_val = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    best_idx = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        cols = r0_offset + lanes
        live = (cols < r0_numel) & row_ok
        loaded = tl.load(in_ptr0 + (cols + 32000*rows), live, eviction_policy='evict_first', other=0.0).to(tl.float32)
        chunk = tl.broadcast_to(loaded, [XBLOCK, R0_BLOCK])
        cand_val, cand_idx = triton_helpers.maximum_with_index(
            best_val, best_idx, chunk, cols
        )
        # Only lanes that actually loaded data may update the running best.
        best_val = tl.where(live, cand_val, best_val)
        best_idx = tl.where(live, cand_idx, best_idx)
    # Collapse the per-lane candidates along the reduction axis.
    row_max, row_argmax = triton_helpers.max_with_index(best_val, best_idx, 1)
    tl.store(out_ptr0 + (rows), row_argmax[:, None], row_ok)
@triton_heuristics.reduction(
    size_hints={'x': 131072, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}}
)
@triton.jit
def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # For every output row, sum the elementwise product of a 128-long slice of
    # in_ptr0 and the matching slice of in_ptr1, accumulating in float32, and
    # store (sum - 0.0) to out_ptr1[row].
    # Both extents are baked in; the runtime arguments are overridden. No row
    # mask is applied: the final store is unconditional.
    xnumel = 131072
    r0_numel = 128
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    lanes = tl.arange(0, R0_BLOCK)[None, :]
    # in_ptr0 is addressed through a decomposition of the flat row index
    # (2048 x 32 x remaining) with its own strides; in_ptr1 is addressed
    # with rows laid out contiguously.
    row_lo = (rows % 2048)
    row_mid = ((rows // 2048) % 32)
    row_hi = rows // 65536
    acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        cols = r0_offset + lanes
        in_range = cols < r0_numel
        lhs = tl.load(in_ptr0 + (cols + 128*row_mid + 4096*row_lo + 8388608*row_hi), in_range, eviction_policy='evict_first', other=0.0).to(tl.float32)
        rhs = tl.load(in_ptr1 + (cols + 128*rows), in_range, eviction_policy='evict_first', other=0.0).to(tl.float32)
        product = tl.broadcast_to(lhs * rhs, [XBLOCK, R0_BLOCK])
        # Lanes beyond the reduction extent keep their previous partial sum.
        acc = tl.where(in_range, acc + product, acc)
    row_sum = tl.sum(acc, 1)[:, None]
    result = row_sum.to(tl.float32) - 0.0
    tl.store(out_ptr1 + (rows), result, None)
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + 
(x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/vm/cvmrbypqporyupnmgxkxj7m4wlj6n3nml7s3nfg7wbcy6munhxc4.py b/SpecForge-ext/cache/compiled_kernels/vm/cvmrbypqporyupnmgxkxj7m4wlj6n3nml7s3nfg7wbcy6munhxc4.py new file mode 100644 index 0000000000000000000000000000000000000000..7acaf109a62cfbe1d41829a72cc9a542c65962ae --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vm/cvmrbypqporyupnmgxkxj7m4wlj6n3nml7s3nfg7wbcy6munhxc4.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', 
index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + 
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/vp/cvpf44p6d56ladmrqf7fvuche2mzrfe4zaaok6pmgcub4rrhj5rm.py b/SpecForge-ext/cache/compiled_kernels/vp/cvpf44p6d56ladmrqf7fvuche2mzrfe4zaaok6pmgcub4rrhj5rm.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9ac23b4a2c3c8749e61dbeef153c5b25fdb96c --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/vp/cvpf44p6d56ladmrqf7fvuche2mzrfe4zaaok6pmgcub4rrhj5rm.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + 
r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/vp/cvpm4f4ow3q465hirisdhyfdqqp42fbrfdmpbc3zwakmuvuqzqlw.py b/SpecForge-ext/cache/compiled_kernels/vp/cvpm4f4ow3q465hirisdhyfdqqp42fbrfdmpbc3zwakmuvuqzqlw.py new file mode 100644 index 0000000000000000000000000000000000000000..fbca5ab24e161260d20992d3d467d69428630013 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vp/cvpm4f4ow3q465hirisdhyfdqqp42fbrfdmpbc3zwakmuvuqzqlw.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from 
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/iq/ciq6pd6iqdrvslwjkevfrajcgblzxkn3jrpbf4qlytcsuwipr2ew.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => 
iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:0"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = 
(%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, 2048][2048, 
1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': 
None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + 
tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sj/csjdg6h6lsx2eqcvfamcrwha5tqlrysjhxxplgj5mowd6shnhhwc.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => 
full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:0" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:0" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:0" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:0" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:0" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_7 : 
Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 
16][16, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_11 : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, 
step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, 
torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import 
triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + 
tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ce/ccefvxdobtenaocabewevaf45h6p7aehp54xd7ccf3732l3vxvwg.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:0" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:0" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:0" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : 
Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': 
{'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) 
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream0 = get_raw_stream(0) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream0) + del arg0_1 + buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream0 = get_raw_stream(0) + triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream0) + buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 
272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream0 = get_raw_stream(0) + triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream0) + buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream0 = get_raw_stream(0) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream0) + del buf0 + buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream0 = get_raw_stream(0) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream0) + del buf8 + buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = 
empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream0 = get_raw_stream(0) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream0) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/vp/fc5c5a3901b589a8125f21383b9577399e45391d205f4cdefab44647a36cdd23.best_config b/SpecForge-ext/cache/compiled_kernels/vp/fc5c5a3901b589a8125f21383b9577399e45391d205f4cdefab44647a36cdd23.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c1c51c5048e176f0cf0b0d2646bd98c4186a3cba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vp/fc5c5a3901b589a8125f21383b9577399e45391d205f4cdefab44647a36cdd23.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "EGDJYO36DUYGK3UQBUH6S7RMVKF77GGHWVMFFZR5R4TDMIZ4YVJA"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/vz/cvzofvv5xx3zbd3qsg6ytmxqt6aoybtka6jzw6fa2ybplg6reklt.py b/SpecForge-ext/cache/compiled_kernels/vz/cvzofvv5xx3zbd3qsg6ytmxqt6aoybtka6jzw6fa2ybplg6reklt.py new file mode 100644 index 0000000000000000000000000000000000000000..058c7510c67ac39579effa4e230462f7585485ed --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vz/cvzofvv5xx3zbd3qsg6ytmxqt6aoybtka6jzw6fa2ybplg6reklt.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/w7/2e1637f37ee9c7309be22551b810f3130488ab4e65e5be4b666ffc744bb8f913.best_config b/SpecForge-ext/cache/compiled_kernels/w7/2e1637f37ee9c7309be22551b810f3130488ab4e65e5be4b666ffc744bb8f913.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbafec32134c55e06add8fdbae259cebe3543 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/w7/2e1637f37ee9c7309be22551b810f3130488ab4e65e5be4b666ffc744bb8f913.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": 
"6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/w7/cw7dpbzsfanlem7ovndaqtk32qivdbyixvhqs7vmjcsetp7fyayh.py b/SpecForge-ext/cache/compiled_kernels/w7/cw7dpbzsfanlem7ovndaqtk32qivdbyixvhqs7vmjcsetp7fyayh.py new file mode 100644 index 0000000000000000000000000000000000000000..c202dd21bd7f7d2db9287c1f1df73599c84b4496 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/w7/cw7dpbzsfanlem7ovndaqtk32qivdbyixvhqs7vmjcsetp7fyayh.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], 
(6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + 
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. 
+ # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/w7/cw7oln47tlgyzufj4xf52hdvopoe4spnpx37ofwjpjfjn6kmdjvi.py b/SpecForge-ext/cache/compiled_kernels/w7/cw7oln47tlgyzufj4xf52hdvopoe4spnpx37ofwjpjfjn6kmdjvi.py new file mode 100644 index 0000000000000000000000000000000000000000..e98455303cdb443e3a740e7e30240fb43d7c0430 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/w7/cw7oln47tlgyzufj4xf52hdvopoe4spnpx37ofwjpjfjn6kmdjvi.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from 
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bl/cbl5w3gbsoiwosnwknfarb55lklpfntpuz6q4jjui35yi2wdwepo.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = 
call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/js/cjsua5bk75g3qrlvlcodjiakenr5rs5kkcshkbp7bxk6gmv42cax.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = 
PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[8][1]cuda:7" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, 
%getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): 
[['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. 
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + 
assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream7) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 8, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = 
runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import 
compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/w7/cw7u6rgkoihwizuge644mng2ddg3bqiixfxegsmxugtblafc6kz3.py b/SpecForge-ext/cache/compiled_kernels/w7/cw7u6rgkoihwizuge644mng2ddg3bqiixfxegsmxugtblafc6kz3.py new file mode 100644 index 0000000000000000000000000000000000000000..d3124ef6fda4defcb77fd5198d945184fdbfb82b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/w7/cw7u6rgkoihwizuge644mng2ddg3bqiixfxegsmxugtblafc6kz3.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/xg/cxgas2jt2e7frwpm3nd5h7wlzdo2fb2yonkvkfoarpyafiwosg23.py b/SpecForge-ext/cache/compiled_kernels/xg/cxgas2jt2e7frwpm3nd5h7wlzdo2fb2yonkvkfoarpyafiwosg23.py new file mode 100644 index 0000000000000000000000000000000000000000..99dd69ff9bdc79cf03c4dcda61cd796d1a515f42 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xg/cxgas2jt2e7frwpm3nd5h7wlzdo2fb2yonkvkfoarpyafiwosg23.py @@ -0,0 +1,46 @@ + +import triton 
+import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, 
R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/xg/cxgourummzwsux6r2gxe7ifvqpdhpgvgbs36tkitfwpr24b4gcvt.py b/SpecForge-ext/cache/compiled_kernels/xg/cxgourummzwsux6r2gxe7ifvqpdhpgvgbs36tkitfwpr24b4gcvt.py new file mode 100644 index 0000000000000000000000000000000000000000..f40336ea0b497c667cc8439c3ddf33ebead32e4d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xg/cxgourummzwsux6r2gxe7ifvqpdhpgvgbs36tkitfwpr24b4gcvt.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 
'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + 
tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/xo/cxogqdjmha6a2mjw43q2sd56tohi4ncj3zpekedkws2baol4oehq.py b/SpecForge-ext/cache/compiled_kernels/xo/cxogqdjmha6a2mjw43q2sd56tohi4ncj3zpekedkws2baol4oehq.py new file mode 100644 index 0000000000000000000000000000000000000000..f057caf21207c0c701e048b9a3cc6fdd6aa29345 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xo/cxogqdjmha6a2mjw43q2sd56tohi4ncj3zpekedkws2baol4oehq.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 
'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + 
_tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/xr/cxrs4oh2cbzbv4dn33jgbva7pkoqpa4buz6xfi73qi3l56d4nsbd.py b/SpecForge-ext/cache/compiled_kernels/xr/cxrs4oh2cbzbv4dn33jgbva7pkoqpa4buz6xfi73qi3l56d4nsbd.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd6999b978e6f44a4248a7d938da192eb337f78 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xr/cxrs4oh2cbzbv4dn33jgbva7pkoqpa4buz6xfi73qi3l56d4nsbd.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/xz/cxzhauqdy36cnyz3lpbu7kkki6oos2euxi2mut56lp5d6vd6i4z2.py b/SpecForge-ext/cache/compiled_kernels/xz/cxzhauqdy36cnyz3lpbu7kkki6oos2euxi2mut56lp5d6vd6i4z2.py new file mode 100644 index 0000000000000000000000000000000000000000..a144d8d25fa19a736fb4202af3316ed390414508 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xz/cxzhauqdy36cnyz3lpbu7kkki6oos2euxi2mut56lp5d6vd6i4z2.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + 
+from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 
'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + 
K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. 
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + 
stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/yz/cyzv2y6nd2vqpgbprk3dyfccmr5pih6uvdo7rya5kdf7iczgzzap.py b/SpecForge-ext/cache/compiled_kernels/yz/cyzv2y6nd2vqpgbprk3dyfccmr5pih6uvdo7rya5kdf7iczgzzap.py new file mode 100644 index 0000000000000000000000000000000000000000..75f02ce3d4bd2968029b4cae76947c9e1e2b32e6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yz/cyzv2y6nd2vqpgbprk3dyfccmr5pih6uvdo7rya5kdf7iczgzzap.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile 
+from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vp/cvpf44p6d56ladmrqf7fvuche2mzrfe4zaaok6pmgcub4rrhj5rm.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, 
%primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fw/cfwnjef76r37zlemz2nqxs6g7h5pbg3xlsgem26kfuesgumqqznd.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 
128, 1]cuda:3" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:3" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[2, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:3" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[2, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:3" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, 
%joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): 
[['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : 
tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. 
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each Q block. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each Q block. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each KV block. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each KV block. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each Q block. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each Q block. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each KV block. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each KV block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster.
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + 
assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream3) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 2, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del 
primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32) + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32) + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32) + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32) + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:3', 
dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/z3/cz3pr3u2ziz6bgnei3ljbppkwjdq4m4h4tz7tqzbwrrt3atkbqxw.py b/SpecForge-ext/cache/compiled_kernels/z3/cz3pr3u2ziz6bgnei3ljbppkwjdq4m4h4tz7tqzbwrrt3atkbqxw.py new file mode 100644 index 0000000000000000000000000000000000000000..120893e26588728f0d65d842d5c9861c649535e9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/z3/cz3pr3u2ziz6bgnei3ljbppkwjdq4m4h4tz7tqzbwrrt3atkbqxw.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = 
torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4q/c4qgb56izzxslbbeos7excinlzn4gyacgk7ghnm4jtlao6kr7gyd.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = 
call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 
'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = 
arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream2) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/za/czavugmlxltvuliq4dhv6chq23pigmyj6dktmupor2jiyiklfef6.py b/SpecForge-ext/cache/compiled_kernels/za/czavugmlxltvuliq4dhv6chq23pigmyj6dktmupor2jiyiklfef6.py new file mode 100644 index 0000000000000000000000000000000000000000..b15929322360f1a46661593ae76f90c35ce83e1f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/za/czavugmlxltvuliq4dhv6chq23pigmyj6dktmupor2jiyiklfef6.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import 
libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', 
'_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX 
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + 
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zf/czfjpc2bmajr63w5p55xld3zlnib2cxdzyx3n2uan6r324ujkv52.py b/SpecForge-ext/cache/compiled_kernels/zf/czfjpc2bmajr63w5p55xld3zlnib2cxdzyx3n2uan6r324ujkv52.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b1e4d25997fd29010e686e846b9b704c05ebad --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zf/czfjpc2bmajr63w5p55xld3zlnib2cxdzyx3n2uan6r324ujkv52.py @@ -0,0 +1,44 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 
= tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/zf/czfnbhok6wj37txssq5tanx3rrseuguo43buuqf3rrut6max4ivs.py b/SpecForge-ext/cache/compiled_kernels/zf/czfnbhok6wj37txssq5tanx3rrseuguo43buuqf3rrut6max4ivs.py new file mode 100644 index 0000000000000000000000000000000000000000..3703da49a47b02e9fc8c65af5f6725cbbd005967 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zf/czfnbhok6wj37txssq5tanx3rrseuguo43buuqf3rrut6max4ivs.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + 
PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. 
+ # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/zn/cznl7iprs6w6hjf3v7kpinw74h6ritdb2cyiauvwzci53zrocsyc.py b/SpecForge-ext/cache/compiled_kernels/zn/cznl7iprs6w6hjf3v7kpinw74h6ritdb2cyiauvwzci53zrocsyc.py new file mode 100644 index 0000000000000000000000000000000000000000..7967f6d260de413a004bc34c0309b530e83ce15f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zn/cznl7iprs6w6hjf3v7kpinw74h6ritdb2cyiauvwzci53zrocsyc.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 
'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) 
+@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file